What is Knowledge Distillation?
Knowledge distillation (KD) is a machine learning (ML) compression process that transfers knowledge from a large deep learning (DL) model to a smaller, more efficient model. In this context, knowledge refers to the patterns and behavior that the original model learned during training.
The goal of knowledge distillation is to reduce the memory footprint, compute requirements, and energy costs of a large model, so it can be used in a resource-constrained environment without significantly sacrificing performance.
Knowledge distillation is an effective way to improve the accuracy of a relatively small ML model. The process itself is sometimes referred to as teacher/student learning. The large model is the teacher, and the small model is the student.
How Knowledge Distillation Works
During the distillation process, the teacher model (typically a large, pre-trained model) generates soft labels on the training data. The soft labels, which are essentially output probability distributions, are then used to train the student model.
- Hard labels name a single class. Given an image of a table, for example, the hard label is “table.”
- Soft labels are output probability distributions that state the model’s confidence scores across labels. For that same image, the teacher model’s output might be 90% “table”, 8% “desk”, and 2% “chair”.
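The difference between the two label types can be sketched in a few lines of Python. This is a minimal illustration, not a real model: the class names and the teacher's raw scores (logits) are made-up values, and the softmax function is assumed as the way the scores are turned into probabilities.

```python
import math

def softmax(logits):
    """Convert raw scores (logits) into a probability distribution."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical class names and teacher logits for one image of a table.
classes = ["table", "desk", "chair"]
teacher_logits = [4.0, 1.6, 0.2]

# Soft label: one confidence score per class.
soft_label = softmax(teacher_logits)

# Hard label: only the single most likely class.
hard_label = classes[soft_label.index(max(soft_label))]

print(hard_label)
print({c: round(p, 3) for c, p in zip(classes, soft_label)})
```

The soft label keeps the information that the teacher considers “desk” a more plausible mistake than “chair”, which the hard label throws away.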
Here is a very simple example of how knowledge distillation could be used to train a student model:
1. Train the teacher model on a dataset.
2. Generate soft labels from the teacher model for the same dataset.
3. Train the student model on the same dataset with the soft labels.
4. Fine-tune the student model on the dataset with hard labels. (Note: Steps 3 and 4 can be combined, for example by training on a weighted sum of a soft-label loss and a hard-label loss.)
5. Evaluate the performance of the student model in terms of a loss function that quantifies how well the model’s predictions match desired outcomes.
If the performance level of the student model is acceptable, the student model can be deployed. If it is unacceptable, the student model can be retrained with additional data or optimized by adjusting hyperparameters such as the learning rate and the distillation temperature.
A higher temperature makes the probability distribution softer (less peaky), while a lower temperature makes it sharper (more peaky and closer to the hard labels).
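The effect of temperature can be seen by dividing the logits by a temperature value before applying softmax, which is the usual way temperature is applied in distillation. The sketch below assumes made-up teacher logits; the function name `softmax_with_temperature` is illustrative.

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Softmax over logits divided by a temperature (must be positive)."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [z / temperature for z in logits]
    shift = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - shift) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

teacher_logits = [4.0, 1.6, 0.2]  # hypothetical values

sharp = softmax_with_temperature(teacher_logits, temperature=0.5)
plain = softmax_with_temperature(teacher_logits, temperature=1.0)
soft = softmax_with_temperature(teacher_logits, temperature=4.0)

# The top probability shrinks as the temperature grows.
for name, dist in [("T=0.5", sharp), ("T=1", plain), ("T=4", soft)]:
    print(name, [round(p, 3) for p in dist])
```

At a low temperature almost all of the probability mass sits on the top class, while at a high temperature the less likely classes receive a larger share.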
Knowledge Distillation Optimization Techniques
- Attention Transfer
Attention transfer is a technique in which the student model is trained to mimic the attention maps generated by the teacher model. Attention maps highlight the important regions in an image or a sequence of words.
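One common formulation, assumed here, builds an attention map from a layer's activations by summing the squared activations across channels, normalizes it, and penalizes the difference between the student's map and the teacher's map. The sketch below uses tiny made-up activations; the function names are illustrative and other ways of defining attention maps exist.

```python
import math

def attention_map(feature_maps):
    """Collapse C feature maps (each H x W) into one flattened, L2-normalized map.

    Each spatial position gets the sum of squared activations across channels.
    """
    height, width = len(feature_maps[0]), len(feature_maps[0][0])
    flat = [
        sum(channel[i][j] ** 2 for channel in feature_maps)
        for i in range(height)
        for j in range(width)
    ]
    norm = math.sqrt(sum(v * v for v in flat)) or 1.0  # avoid dividing by zero
    return [v / norm for v in flat]

def attention_transfer_loss(student_maps, teacher_maps):
    """Mean squared difference between the two normalized attention maps."""
    s = attention_map(student_maps)
    t = attention_map(teacher_maps)
    return sum((a - b) ** 2 for a, b in zip(s, t)) / len(s)

# Hypothetical activations: the teacher has 3 channels, the student has 2,
# and both layers share the same 2 x 2 spatial size.
teacher = [
    [[0.1, 0.9], [0.2, 0.8]],
    [[0.0, 1.0], [0.1, 0.7]],
    [[0.2, 0.6], [0.0, 0.9]],
]
student = [
    [[0.3, 0.5], [0.4, 0.4]],
    [[0.2, 0.6], [0.3, 0.5]],
]

print(attention_transfer_loss(student, teacher))
```

Because the channels are summed away, the teacher and student layers may have different channel counts as long as their spatial sizes match.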
- FitNets
FitNets is a technique in which the student model is trained to match the intermediate representations of the teacher model. Intermediate representations are the outputs of the model’s hidden layers, which capture the underlying features of the input data.
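Matching intermediate representations can be sketched as a mean squared error between a teacher hidden-layer output and a student hidden-layer output. Because the student layer is usually narrower, a small regressor is assumed here to project the student features to the teacher's size. All values and names below are hypothetical, and the regressor weights are fixed rather than learned to keep the example short.

```python
def project(features, weights):
    """Linear regressor: map student features into the teacher's feature size."""
    return [sum(w * x for w, x in zip(row, features)) for row in weights]

def hint_loss(student_features, teacher_features, weights):
    """Mean squared error between projected student features and teacher features."""
    projected = project(student_features, weights)
    squared = [(p - t) ** 2 for p, t in zip(projected, teacher_features)]
    return sum(squared) / len(teacher_features)

# Hypothetical hidden-layer outputs for one input.
teacher_hidden = [0.5, -1.0, 2.0]  # teacher layer has 3 units
student_hidden = [1.0, 2.0]        # student layer has 2 units

# Hypothetical regressor weights (3 x 2). In practice these would be
# learned together with the student's own weights.
regressor = [[0.1, 0.2], [-0.3, -0.1], [0.4, 0.6]]

print(hint_loss(student_hidden, teacher_hidden, regressor))
```

Minimizing this loss pushes the student's hidden layer toward features from which the teacher's representation can be recovered.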
- Similarity-Based Distillation
In this technique, the student model is trained to match the similarity matrix of the teacher model. The similarity matrix measures the pairwise similarities between multiple input samples.
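A minimal sketch of this idea follows, assuming dot-product similarity between the feature vectors of the samples in a batch and row-wise normalization of the resulting matrix. The batches and function names are made up for illustration; other similarity measures could be substituted.

```python
import math

def similarity_matrix(batch_features):
    """Pairwise dot-product similarities between samples, each row L2-normalized."""
    gram = [
        [sum(a * b for a, b in zip(x, y)) for y in batch_features]
        for x in batch_features
    ]
    normalized = []
    for row in gram:
        norm = math.sqrt(sum(v * v for v in row)) or 1.0  # avoid dividing by zero
        normalized.append([v / norm for v in row])
    return normalized

def similarity_loss(student_batch, teacher_batch):
    """Mean squared difference between student and teacher similarity matrices."""
    s = similarity_matrix(student_batch)
    t = similarity_matrix(teacher_batch)
    n = len(s)
    return sum((s[i][j] - t[i][j]) ** 2 for i in range(n) for j in range(n)) / (n * n)

# Hypothetical features for the same 3 inputs: 4 values per sample from the
# teacher, 2 values per sample from the student.
teacher_batch = [[1.0, 0.0, 2.0, 1.0], [0.9, 0.1, 2.1, 1.0], [-1.0, 2.0, 0.0, 0.5]]
student_batch = [[0.8, 0.1], [0.7, 0.2], [-0.2, 0.9]]

print(similarity_loss(student_batch, teacher_batch))
```

Both matrices are 3 x 3 regardless of feature width, so the student only has to agree with the teacher about which inputs are similar to each other, not reproduce the teacher's features.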
- Hint-Based Distillation
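Hint-based distillation is a technique in which the output of a chosen hidden layer of the teacher model, called the hint, guides the training of a corresponding layer in the student model. The student is trained to reduce the difference between its own layer output and the hint. FitNets is a well-known example.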
- Cross-Entropy Distillation
The student model is trained using a loss function that combines the standard classification loss with a distillation loss that measures the difference between the teacher and student model’s output probabilities.
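A combined loss of this kind can be sketched as a weighted sum of two terms: cross-entropy against the hard label, and the Kullback-Leibler divergence between the temperature-softened teacher and student distributions. The weighting factor `alpha`, the temperature value, and the logits below are assumptions chosen for illustration. Multiplying the soft term by the squared temperature is a common convention that keeps its gradients on a comparable scale.

```python
import math

def softmax(logits, temperature=1.0):
    """Softmax over logits divided by a positive temperature."""
    scaled = [z / temperature for z in logits]
    shift = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - shift) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(probs, true_index):
    """Standard classification loss against a hard label."""
    return -math.log(probs[true_index])

def kl_divergence(p, q):
    """KL(p || q): how far distribution q is from the reference distribution p."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def distillation_loss(student_logits, teacher_logits, true_index,
                      temperature=2.0, alpha=0.5):
    """Weighted sum of the hard-label loss and the soft-label (distillation) loss."""
    hard = cross_entropy(softmax(student_logits), true_index)
    soft = kl_divergence(
        softmax(teacher_logits, temperature),
        softmax(student_logits, temperature),
    )
    return alpha * hard + (1 - alpha) * (temperature ** 2) * soft

# Hypothetical logits for one training example whose true class is index 0.
teacher_logits = [4.0, 1.6, 0.2]
student_logits = [2.0, 1.9, 0.5]

print(distillation_loss(student_logits, teacher_logits, true_index=0))
```

Setting `alpha` to 1 reduces this to ordinary supervised training, while setting it to 0 trains the student on the teacher's soft labels alone.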
Knowledge distillation is an important technique for creating lightweight machine-learning models. These distilled models are especially beneficial for recommendation systems and IoT edge devices that have computational constraints. Tasks for which distilled models are commonly used include:
- Classification: Assigning input data to one of several predefined categories.
- Natural Language Processing (NLP): Processing and analyzing large amounts of natural language data for tasks like sentiment analysis and named entity recognition.
- Object Detection: Identifying and classifying objects within images or videos.
- Speech Recognition: Converting spoken language into text.
- Machine Translation: Translating text or speech from one language to another.
Advantages and Disadvantages
One of the main advantages of knowledge distillation is that it enables the creation of smaller and faster models that perform well on Internet of Things (IoT) edge devices. It should be noted, however, that knowledge distillation often involves dealing with a trade-off between size and an acceptable level of accuracy.
Importance of Compression
One of the biggest challenges with developing business-to-consumer (B2C) applications that use artificial intelligence (AI) is that edge computing devices like mobile phones and tablets have limited storage and processing capabilities.
That leaves machine learning engineers with one option if they want to run a large model on an edge device: reduce the size of the model with compression techniques such as neural network pruning, quantization, low-rank factorization, and knowledge distillation.
- Pruning: This involves removing the artificial neurons or weights that contribute the least to the model’s performance. The aim is to reduce the model’s size without a significant drop in accuracy.
- Quantization: This involves reducing the precision of a model’s weights (and sometimes activations) by using 16 or 8 bits to represent a weight, for example, instead of using 32 bits. This reduces the model’s size and can also speed up inference, especially on hardware that’s optimized for low-precision computations.
- Low-rank factorization: This involves approximating the weight matrices in a neural network with matrices of lower rank. The idea is that the information contained in the weight matrices can often be captured using fewer parameters.
- Knowledge distillation: This involves training a smaller lightweight machine-learning model to replicate the behavior of a larger, more resource-intensive model.
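Two of the techniques above, pruning and quantization, are simple enough to sketch on a plain list of weights. The example assumes magnitude-based pruning (the smallest absolute weights are treated as contributing the least) and symmetric 8-bit quantization with a single scale factor. The weights and function names are hypothetical, and real toolkits apply these ideas per layer or per channel.

```python
def prune_by_magnitude(weights, fraction):
    """Zero out the given fraction of weights with the smallest absolute value."""
    count = int(len(weights) * fraction)
    smallest = sorted(range(len(weights)), key=lambda i: abs(weights[i]))[:count]
    pruned = list(weights)
    for i in smallest:
        pruned[i] = 0.0
    return pruned

def quantize_int8(weights):
    """Symmetric 8-bit quantization: map floats to integers in [-127, 127]."""
    largest = max(abs(w) for w in weights)
    scale = largest / 127 if largest > 0 else 1.0
    return [round(w / scale) for w in weights], scale

def dequantize(quantized, scale):
    """Recover approximate float weights from the stored integers."""
    return [q * scale for q in quantized]

weights = [0.82, -0.03, 0.41, 0.007, -0.66, 0.12]  # hypothetical layer weights

pruned = prune_by_magnitude(weights, fraction=0.5)
quantized, scale = quantize_int8(weights)
restored = dequantize(quantized, scale)

print(pruned)
print(quantized)
print([round(w, 3) for w in restored])
```

Pruning saves space only when the zeros are stored or skipped efficiently, and quantization trades a small rounding error per weight for a fourfold reduction compared with 32-bit storage.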
Knowledge Distillation vs. Transfer Learning
Knowledge distillation is sometimes referred to as a type of transfer learning, but the two concepts have different purposes.
The goal of knowledge distillation is to create a smaller machine-learning model that can solve the same task as a larger model.
In contrast, the goal of transfer learning is to reduce the time it takes to train a large model to solve a new task by using knowledge gained from a previously learned task.
Knowledge Distillation vs. Data Distillation
Data distillation and knowledge distillation are both compression processes, but they target different components. Data distillation focuses on the training data itself. Its goal is to obtain a smaller subset of the data that still represents the original large dataset.
In contrast, knowledge distillation focuses on reducing a model’s size without significantly sacrificing accuracy.