by Vishal Yadav, Nikhil Pandey
Introduction
Large Language Models (LLMs) have transformed the landscape of natural language processing (NLP) with their ability to understand and generate human-like text. However, their size and complexity often pose challenges in terms of deployment, speed, and cost. For specialized, niche tasks, we often end up deploying the best available model even though we don't utilize all of its capabilities. This is where distillation comes in, offering a method to create (fine-tune) smaller, customized, more efficient models, while retaining much of the performance of a significantly larger state-of-the-art model.
What is distillation?
Distillation is a technique designed to transfer knowledge of a large pre-trained model (the "teacher") into a smaller model (the "student"), enabling the student model to achieve comparable performance to the teacher model. This technique allows users to leverage the high quality of larger LLMs, while reducing inference costs in a production environment, thanks to the smaller student model.
How does distillation work?
In distillation, knowledge can be transferred from teacher to student model in several ways. Here, we specifically discuss response-based, offline distillation, where the student model learns to mimic the output (only predictions) of the teacher model, and the teacher model is not trained during distillation.
- Teacher Model: A large, high-capacity teacher model that is already pre-trained on massive datasets. This model has learned rich representations and complex patterns from the data, which allows it to generalize well even on unseen tasks.
- Knowledge Extraction: The teacher model generates outputs based on given inputs, which are then used as training data for the student model. This involves not just mimicking outputs but also capturing the underlying reasoning processes.
- Student Model Training: A smaller student model is trained using the extracted knowledge as a guide. The student model learns to mimic the teacher model's behavior and predictions on specific tasks.
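To make this flow concrete, here is a minimal, illustrative sketch of response-based, offline distillation. The teacher call is a stub standing in for a large hosted model, and all names are hypothetical; it only illustrates the data flow, not the Azure pipeline described later in this post.

```python
# Illustrative sketch of response-based, offline distillation: the teacher is frozen
# and only its outputs are used. The teacher here is a stub, not a real model.

def teacher_predict(prompt: str) -> str:
    # Stand-in for a large, pre-trained teacher model producing a label.
    return "entailment" if "animal" in prompt.lower() else "neutral"

# Unlabeled, task-specific inputs.
unlabeled_inputs = [
    "Premise: The cat sat on the mat. Hypothesis: An animal is on the mat.",
    "Premise: All birds can fly. Hypothesis: Penguins can fly.",
]

# Step 1 (knowledge extraction): the frozen teacher labels the unlabeled inputs.
synthetic_dataset = [
    {"input": prompt, "label": teacher_predict(prompt)} for prompt in unlabeled_inputs
]

# Step 2 (student training): the synthetic dataset becomes supervised
# fine-tuning data for the smaller student model.
for example in synthetic_dataset:
    print(example)
```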
Advantages
- Reduced Size: The resulting student model is significantly smaller, making it easier to deploy in resource-constrained environments.
- Lower Cost: Running smaller models incurs lower operational costs while maintaining competitive performance levels.
- Task-Specific Optimization: Distillation can be tailored for specific applications, enhancing efficiency and accuracy.
- Performance: Smaller models exhibit significantly lower latency compared to larger models, which in turn boosts the throughput of the deployment.
- Customization: Distillation allows users to select desirable traits from multiple larger models and transfer them to smaller models.
- Personalization: Personality traits can be incorporated into the model, enabling it to respond with relevant answers when queried about its personality.
- Synthetic Data Generation: Data generation at scale can be done either for labels only or from scratch using just seed/metadata.
- Generalization: Distillation can help student models generalize better by learning from the teacher model's knowledge and avoiding overfitting.
- Improved Multilingual Capabilities: The multilingual performance of smaller models can be significantly enhanced with the help of teacher models, making them suitable for global applications.
Distillation in Azure AI Foundry
Distillation as a service is now supported on Azure, covering a variety of task types, with more to be added soon. The following tasks are currently supported.
- Summarization: Given a document (article) to summarize, generate an entity-dense summary of the document.
- Conversational Assistant: Generate AI assistant responses on single-turn and multi-turn conversational datasets. To generate each response, the available chat history and the current user prompt are utilized.
- Natural Language Understanding (NLU)
  - MATH: Generate numeric answers to math problems.
  - Natural Language Inference (NLI): Given a premise and a hypothesis, determine whether the premise entails the hypothesis, contradicts it, or is neutral, i.e., neither entails nor contradicts it.
  - Multiple-Choice Question Answering: Given a question and answer choices, determine the correct answer choice.
Distillation Process
Figure: Overview of the two-step distillation process: (1) generate synthetic data using a task-specific, elaborate prompt; (2) train (and infer from) the student model using a shorter prompt. (Figure source: https://arxiv.org/pdf/2410.18588)
The distillation process involves two main steps: generating high-quality synthetic data (labels) using the teacher model, followed by instruction-based fine-tuning of the student model.
Data Generation
High-quality data generation is crucial for the student model's performance. Azure provides a proprietary library of advanced prompts to generate high-quality synthetic data for all supported tasks, utilizing techniques such as Chain of Thought (CoT) and Chain of Density (CoD), along with other best practices. This option can be enabled by passing the `enable_chain_of_thought` parameter when invoking the distillation pipeline, ensuring reasoning-based answers and, consequently, high-quality data for distillation.
Instruction Fine-Tuning
The next step is to fine-tune the smaller model using the task-specific generated data. This involves using a concise, task-specific prompt and training with the input and generated output (excluding reasoning steps). These innovations ensure significant performance gains for a given task while minimizing the cost (number of tokens) for the user. When using user-provided prompts, the same prompt is applied in both data generation and fine-tuning.
Distillation Code Snippet
Distillation is supported via the Azure SDK and CLI. Support was added in version 1.22.0 of azure-ai-ml, so ensure that the azure-ai-ml package is >= 1.22.0 before using the code snippet below.
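As a starting point, here is a minimal sketch of submitting a distillation job with the Python SDK. It follows the preview distillation API in azure-ai-ml; exact module paths, enum values, and parameter names may differ slightly in your SDK version, and the subscription details, connection names, endpoints, and file paths below are placeholders.

```python
# Minimal sketch of an NLI distillation job with azure-ai-ml >= 1.22.0 (preview API).
# Module paths and parameter names are assumptions; all IDs/paths are placeholders.
from azure.ai.ml import MLClient, Input, Output
from azure.identity import DefaultAzureCredential
from azure.ai.ml.constants import DataGenerationTaskType, DataGenerationType
from azure.ai.ml.entities import ServerlessConnection
from azure.ai.ml.model_customization import distillation, PromptSettings

ml_client = MLClient(
    credential=DefaultAzureCredential(),
    subscription_id="<SUBSCRIPTION_ID>",
    resource_group_name="<RESOURCE_GROUP>",
    workspace_name="<AI_FOUNDRY_PROJECT_NAME>",  # West US 3 for a Llama 3.1 8B student
)

distillation_job = distillation(
    experiment_name="llama-nli-distillation",
    data_generation_type=DataGenerationType.LABEL_GENERATION,  # teacher generates labels
    data_generation_task_type=DataGenerationTaskType.NLI,      # supported task type
    teacher_model_endpoint_connection=ServerlessConnection(
        name="Meta-Llama-3-1-405B-Instruct-connection",
        endpoint="<TEACHER_ENDPOINT_URL>",
        api_key="<TEACHER_ENDPOINT_KEY>",
    ),
    student_model="azureml://registries/azureml-meta/models/Meta-Llama-3.1-8B-Instruct/versions/1",
    training_data=Input(type="uri_file", path="train.jsonl"),
    validation_data=Input(type="uri_file", path="validation.jsonl"),
    outputs={"registered_model": Output(type="mlflow_model", name="llama-nli-distilled")},
)

# Ask the teacher to reason step by step during synthetic data generation.
distillation_job.set_prompt_settings(
    prompt_settings=PromptSettings(enable_chain_of_thought=True)
)

submitted_job = ml_client.jobs.create_or_update(distillation_job)
print(f"Submitted distillation job: {submitted_job.name}")
```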
Model Offerings
Teacher Models
Currently Meta Llama 3.1 405B Instruct is supported as the teacher model for distillation.
Student Models
Currently Meta Llama 3.1 8B Instruct is supported as the student model for distillation. All of Microsoft's Phi 3 and Phi 3.5 Instruct series models will also be available for distillation soon. The following table shows our current and upcoming student model offerings.
| Student Model | Region | Availability |
| --- | --- | --- |
| Meta Llama 3.1 8B Instruct | West US 3 | Available |
| Phi 3/3.5 Instruct | East US 2 | Coming Soon |
At the time of this writing, fine-tuning of the Meta Llama 3.1 Instruct series of models, and deployment of such fine-tuned models, is only available in the West US 3 region, whereas fine-tuning of Microsoft's Phi 3 Instruct series of models, and deployment of such fine-tuned models, is only available in the East US 2 region. Ensure your AI Foundry project is set up in the appropriate region for your selected student model.
Notebooks
Distilling Large Language Models for NLI Tasks: A Practical Guide
Notebook - Distillation with Large Language Models
This notebook provides a comprehensive guide on how to distill a large teacher model into a smaller student model, specifically for Natural Language Inference (NLI) tasks. It uses Meta Llama 3.1 405B Instruct as the teacher model and Meta Llama 3.1 8B Instruct as the student model.
Key Highlights
- Teacher and Student Models: The process uses Meta Llama 3.1 405B Instruct as the teacher model and Meta Llama 3.1 8B Instruct as the student model.
- Prerequisites: Ensure you have subscribed to the required models and set up an AI Foundry project in the West US 3 region for distillation of a Meta Llama 3.1 8B Instruct student model.
- SDK Installation: Install necessary SDKs such as azure-ai-ml, azure-identity, and mlflow.
- Dataset Preparation: Use the ConjNLI dataset from Hugging Face for training and validation (a data preparation sketch follows this list).
- Distillation Job: Configure and run the distillation job to transfer knowledge from the teacher to the student model.
- Deployment: Optionally, deploy the distilled model to a serverless endpoint and perform sample inferences.
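As a rough illustration of the dataset preparation step, the sketch below converts ConjNLI-style premise/hypothesis rows into chat-format JSONL files for the distillation job. The local file names, column names, and the exact JSONL schema are assumptions for illustration; follow the notebook for the authoritative format.

```python
# Convert ConjNLI-style TSV rows (premise, hypothesis) into chat-format JSONL.
# File names, column names, and the JSONL layout are illustrative assumptions.
import csv
import json

NLI_SYSTEM_PROMPT = (
    "Decide whether the premise entails, contradicts, or is neutral with respect to "
    "the hypothesis. Answer with exactly one of: entailment, contradiction, neutral."
)

def to_chat_records(tsv_path: str, jsonl_path: str) -> None:
    """Write one chat-format record per row; the teacher model fills in the labels."""
    with open(tsv_path, newline="", encoding="utf-8") as src, \
         open(jsonl_path, "w", encoding="utf-8") as dst:
        for row in csv.DictReader(src, delimiter="\t"):
            record = {
                "messages": [
                    {"role": "system", "content": NLI_SYSTEM_PROMPT},
                    {
                        "role": "user",
                        "content": f"Premise: {row['premise']}\nHypothesis: {row['hypothesis']}",
                    },
                ]
            }
            dst.write(json.dumps(record) + "\n")

to_chat_records("conjnli_train.tsv", "train.jsonl")
to_chat_records("conjnli_dev.tsv", "validation.jsonl")
```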
This notebook simplifies the complex task of model distillation, making it accessible even to those new to NLP and model training.
Results
Using the ConjNLI dataset and Chain of Thought (CoT) distillation, we obtain the following accuracy (%) metrics.
| Dataset | Student Model | Teacher (405B) with CoT Prompting | Student with CoT Prompting | Student Distilled on CoT-prompted Teacher Output |
| --- | --- | --- | --- | --- |
| ConjNLI (dev) | Meta Llama 3.1 8B Instruct | 69.98 | 52.81 | 63.88 |
| ConjNLI (dev) | Phi 3 Mini 128k Instruct | 69.98 | 43.98 | 57.78 |
Distillation with the Meta Llama 3.1 8B Instruct and Phi 3 Mini 128k Instruct student models provides approximately 21% and 31% relative improvement, respectively, over directly prompting the student model with a CoT prompt. For detailed results on other datasets and tasks, we refer the reader to the published results in our knowledge distillation paper.
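For clarity, these figures are relative improvements over the CoT-prompted student baseline; a quick check against the table above:

```python
# Relative improvement of the distilled student over the CoT-prompted student baseline.
def relative_improvement(distilled: float, baseline: float) -> float:
    return (distilled - baseline) / baseline * 100

print(round(relative_improvement(63.88, 52.81), 1))  # ~21.0% for Meta Llama 3.1 8B Instruct
print(round(relative_improvement(57.78, 43.98), 1))  # ~31.4% for Phi 3 Mini 128k Instruct
```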
Conclusion
Distillation represents a significant step forward in the development and deployment of LLMs/SLMs at scale. By transferring the knowledge from a large pre-trained model (teacher) to a smaller, more efficient model (student), distillation offers a practical solution to the challenges of deploying large models, such as high costs and complexity. This technique not only reduces model size and operational costs but also enhances the performance of student models on specific tasks. The support for distillation in Azure AI Foundry further simplifies the process, making it accessible for various applications, such as summarization, conversational assistance, and natural language understanding tasks. Furthermore, the detailed, hands-on example notebooks provided on the Azure GitHub can help facilitate easier adoption.
In summary, distillation not only bridges the gap between generalist understanding and specialized application but also paves the way for a more sustainable and practical approach to leveraging LLMs in real-world scenarios.
Updated Dec 06, 2024
Version 1.0