Add Comprehensive QAT Training Framework for MLC-LLM #3258
Conversation
Commits:
- …6e1-dba0d7748779: Add performance statistics display to mlc_llm serve command
  Co-authored-by: alohachen <126397459+alohachen@users.noreply.github.com>
- …5ee-ce1897090066: Add comprehensive prompt logging for debugging in serve engine
  Co-authored-by: alohachen <126397459+alohachen@users.noreply.github.com>
Add prompt cache token tracking to serve metrics:
- Add prompt_cache_tokens field to RequestMetrics and EngineMetrics
- Track and display cache hit tokens in serve completion logs
- Update all prefix cache matching functions to record cache statistics
- Include prompt cache tokens in JSON metrics output
- Fix format warnings for int64_t printf on different platforms

Output format now includes: Prompt Cache: X tokens
Add complete quantization aware training (QAT) framework:
- Support for ShareGPT format data with multi-file loading
- Smart data sampling strategies (balanced, diverse, quality-based)
- Optimized for Llama3.2-1B models with LoRA fine-tuning
- Comprehensive training metrics and progress logging with plots
- Direct conversion to MLC-LLM q4f16_1 quantization format
- Ready-to-use scripts and configuration examples

Features:
- Multi-file ShareGPT data loader with validation
- Intelligent data sampling from large datasets
- 4-bit quantization aware training using BitsAndBytes (see the sketch below)
- LoRA adaptation for memory-efficient training
- Real-time training monitoring with matplotlib plots
- Automatic weight conversion to MLC-LLM format
- Comprehensive error handling and logging

Usage: qat_training/examples/run_training.sh
Output format: q4f16_1 compatible with MLC-LLM inference
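As context for the BitsAndBytes and LoRA items above, here is a minimal sketch of a 4-bit-base-plus-LoRA training setup. The model id, target modules, and hyperparameters are illustrative assumptions, not values read from this PR's configs:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

MODEL_ID = "meta-llama/Llama-3.2-1B"  # assumed model id

# Load the base model with 4-bit weights; fp16 compute matches the
# half-precision activations of the q4f16_1 target format.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, quantization_config=bnb_config, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Train only small LoRA adapters on top of the frozen 4-bit base.
model = prepare_model_for_kbit_training(model)
lora_config = LoraConfig(
    r=16, lora_alpha=32, lora_dropout=0.05,  # illustrative hyperparameters
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```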
Pull Request Overview
This PR introduces a comprehensive Quantization Aware Training (QAT) framework for MLC-LLM models. Key changes include a full training pipeline with support for multi-file ShareGPT data, smart data sampling strategies, extensive configuration setup for both training and model conversion, and integration of real-time metrics logging and conversion to the MLC-LLM q4f16_1 format.
Reviewed Changes
Copilot reviewed 32 out of 34 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| qat_training/training/qat_trainer.py | Implements model/tokenizer setup, trainer creation, training execution, evaluation, and export for MLC conversion. |
| qat_training/scripts/*.py | Adds scripts for running training, validation, and conversion workflows. |
| qat_training/data/* | Provides multi-file data loading, sampling, and processing for ShareGPT-formatted data. |
| qat_training/conversion/weight_converter.py | Contains weight extraction, group quantization, packing, and saving routines for conversion into MLC-LLM format (see the quantization sketch after this table). |
| qat_training/config/* | Introduces configuration definitions and example configuration files. |
| cpp/serve/* | Updates the C++ serving modules to record the new prompt cache token metrics. |
| python/mlc_llm/serve/* | Enhances logging and prompt processing in the serving engine for better traceability in asynchronous operations. |
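To make the weight_converter.py row concrete, here is a minimal sketch of symmetric group quantization with nibble packing in the spirit of q4f16_1 (4-bit weights, one fp16 scale per group). The group size, zero point of 7, and packing order are assumptions; the PR's weight_converter.py defines the actual byte layout:

```python
import numpy as np

def group_quantize_q4(w: np.ndarray, group_size: int = 32):
    """Quantize a flat fp32 weight array to packed 4-bit groups + fp16 scales."""
    w = w.astype(np.float32).reshape(-1, group_size)
    # One symmetric scale per group; quantized values live in [-7, 7]
    # around an assumed zero point of 7.
    scale = np.maximum(np.abs(w).max(axis=1, keepdims=True) / 7.0, 1e-8)
    q = np.clip(np.round(w / scale) + 7, 0, 15).astype(np.uint32)
    # Pack 8 nibbles into each uint32, lowest nibble first.
    q = q.reshape(-1, 8)
    packed = np.zeros(q.shape[0], dtype=np.uint32)
    for i in range(8):
        packed |= q[:, i] << (4 * i)
    return packed, scale.astype(np.float16).ravel()

def dequantize_q4(packed: np.ndarray, scale: np.ndarray, group_size: int = 32):
    """Inverse of group_quantize_q4, useful for round-trip validation."""
    nibbles = np.stack([(packed >> (4 * i)) & 0xF for i in range(8)], axis=1)
    q = nibbles.reshape(-1, group_size).astype(np.float32) - 7.0
    return (q * scale.astype(np.float32)[:, None]).ravel()
```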
"compute_dtype": "float16" | ||
}, | ||
"converted_from": "qat_training", | ||
"conversion_timestamp": torch.datetime.now().isoformat(), |
It appears that 'torch.datetime.now()' is used to generate a timestamp, but PyTorch does not provide a datetime module. Replace it with the standard Python datetime module (e.g., import datetime and use datetime.datetime.now().isoformat()).
"conversion_timestamp": torch.datetime.now().isoformat(), | |
"conversion_timestamp": datetime.datetime.now().isoformat(), |
Summary
This PR adds a comprehensive Quantization Aware Training (QAT) framework specifically designed for MLC-LLM compatibility. The framework enables training quantized models that can be directly converted to MLC-LLM's q4f16_1 format for efficient inference.
Key Features
Architecture
Usage
Quick Start
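The commit message names qat_training/examples/run_training.sh as the entry point, so a quick start presumably amounts to running `bash qat_training/examples/run_training.sh` (any required arguments are not shown in this excerpt); per the file summary, the conversion workflow under qat_training/scripts/ then produces the q4f16_1 weights.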
Advanced Usage
Technical Details
Benefits
Testing
The framework includes validation scripts to ensure model correctness:
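The specific validation scripts are not shown in this excerpt. As an illustration of one check such a script could perform (an assumption, not code from the PR), here is a round-trip bound on 4-bit group quantization error:

```python
import numpy as np

# Quantize and dequantize a synthetic weight tensor; for 4-bit symmetric
# group quantization the relative round-trip error should stay modest.
rng = np.random.default_rng(0)
w = rng.normal(scale=0.02, size=4096).astype(np.float32)

groups = w.reshape(-1, 32)
scale = np.maximum(np.abs(groups).max(axis=1, keepdims=True) / 7.0, 1e-8)
q = np.clip(np.round(groups / scale), -7, 7)
w_hat = (q * scale).ravel()

rel_err = np.linalg.norm(w - w_hat) / np.linalg.norm(w)
assert rel_err < 0.2, f"unexpectedly large quantization error: {rel_err:.3f}"
print(f"relative round-trip error: {rel_err:.4f}")
```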
Test Plan
Related Issues
This addresses the need for better quantization methods in MLC-LLM: QAT-trained models provide an alternative to the unstable AWQ implementation, achieving better accuracy while maintaining inference efficiency.