Add Comprehensive QAT Training Framework for MLC-LLM #3258

Open · wants to merge 8 commits into base: main

19 changes: 19 additions & 0 deletions cpp/serve/engine_actions/action_commons.cc
@@ -7,6 +7,9 @@

#include <tvm/runtime/nvtx.h>

#include <cinttypes>
#include <cstdio>

namespace mlc {
namespace llm {
namespace serve {
@@ -224,6 +227,22 @@ void ProcessFinishedRequestStateEntries(
rstate->metrics.finish_time_point = trequest_finish;
estate->metrics.RequestFinishUpdate(rstate->metrics);

// Print performance statistics to console
if (rstate->metrics.prefill_tokens > 0 && rstate->metrics.decode_tokens > 0) {
double prefill_time = rstate->metrics.GetPrefillTime();
double decode_time = rstate->metrics.GetDecodeTime();
if (prefill_time > 0 && decode_time > 0) {
double prefill_tps = rstate->metrics.prefill_tokens / prefill_time;
double decode_tps = rstate->metrics.decode_tokens / decode_time;
printf(
"[Request Completed] Prefill: %" PRId64 " tokens, Decode: %" PRId64 " tokens, "
"Prompt Cache: %" PRId64 " tokens, Prefill TPS: %.1f, Decode TPS: %.1f\n",
rstate->metrics.prefill_tokens, rstate->metrics.decode_tokens,
rstate->metrics.prompt_cache_tokens, prefill_tps, decode_tps);
fflush(stdout);
}
}

// always stream back usage in backend
callback_delta_outputs->push_back(RequestStreamOutput::Usage(
root_rsentry->request->id, rstate->metrics.AsUsageJSONStr(true)));
4 changes: 4 additions & 0 deletions cpp/serve/engine_actions/disagg_prepare_recv.cc
@@ -415,6 +415,10 @@ class DisaggPrepareReceiveActionObj : public BatchPrefillBaseActionObj {
// Update max prefill length
input->max_prefill_length =
std::min(input->max_prefill_length, rsentry->mstates[0]->GetInputLength());

// Set prompt cache tokens in metrics
rsentry->rstate->metrics.prompt_cache_tokens = result.prefilled_offset;

return result.prefilled_offset;
}
return 0;
4 changes: 4 additions & 0 deletions cpp/serve/engine_actions/disagg_remote_send.cc
@@ -468,6 +468,10 @@ class DisaggRemoteSendActionObj : public BatchPrefillBaseActionObj {
// Update max prefill length
input->max_prefill_length =
std::min(input->max_prefill_length, rsentry->mstates[0]->GetInputLength());

// Set prompt cache tokens in metrics
rsentry->rstate->metrics.prompt_cache_tokens = result.prefilled_offset;

return result.prefilled_offset;
}
return 0;
4 changes: 4 additions & 0 deletions cpp/serve/engine_actions/eagle_new_request_prefill.cc
@@ -476,6 +476,10 @@ class EagleNewRequestPrefillActionObj : public BatchPrefillBaseActionObj {
// Update max prefill length
input->max_prefill_length =
std::min(input->max_prefill_length, rsentry->mstates[0]->GetInputLength());

// Set prompt cache tokens in metrics
rsentry->rstate->metrics.prompt_cache_tokens = result.prefilled_offset;

return result.prefilled_offset - 1;
}
return 0;
4 changes: 4 additions & 0 deletions cpp/serve/engine_actions/new_request_prefill.cc
@@ -342,6 +342,10 @@ class NewRequestPrefillActionObj : public BatchPrefillBaseActionObj {
// Update max prefill length
input->max_prefill_length =
std::min(input->max_prefill_length, rsentry->mstates[0]->GetInputLength());

// Set prompt cache tokens in metrics
rsentry->rstate->metrics.prompt_cache_tokens = result.prefilled_offset;

return result.prefilled_offset;
}
return 0;
3 changes: 3 additions & 0 deletions cpp/serve/metrics.cc
@@ -79,6 +79,7 @@ picojson::object RequestMetrics::AsJSON() const {
metrics["prefill_tokens"] = picojson::value(prefill_tokens);
metrics["decode_tokens"] = picojson::value(decode_tokens);
metrics["jump_forward_tokens"] = picojson::value(jump_forward_tokens);
metrics["prompt_cache_tokens"] = picojson::value(prompt_cache_tokens);

if (prefill_tokens != 0) {
metrics["prefill_tokens_per_s"] = picojson::value(prefill_tokens / this->GetPrefillTime());
@@ -113,6 +114,7 @@ picojson::object EngineMetrics::AsJSON() const {
metrics["prefill_tokens_sum"] = picojson::value(prefill_tokens_sum);
metrics["decode_tokens_sum"] = picojson::value(decode_tokens_sum);
metrics["jump_forward_tokens_sum"] = picojson::value(jump_forward_tokens_sum);
metrics["prompt_cache_tokens_sum"] = picojson::value(prompt_cache_tokens_sum);

if (prefill_tokens_sum != 0) {
metrics["prefill_tokens_per_s"] = picojson::value(prefill_tokens_sum / engine_prefill_time_sum);
@@ -170,6 +172,7 @@ void EngineMetrics::Reset() {
prefill_tokens_sum = 0;
decode_tokens_sum = 0;
jump_forward_tokens_sum = 0;
prompt_cache_tokens_sum = 0;
last_finished_request.Reset();
spec_decode.Reset();
decode_time_by_batch_size.clear();
6 changes: 6 additions & 0 deletions cpp/serve/metrics.h
@@ -108,6 +108,8 @@ struct RequestMetrics {
int64_t decode_tokens = 0;
/*! \brief The number of tokens predicted by jump-forward decoding. */
int64_t jump_forward_tokens = 0;
/*! \brief The number of tokens retrieved from prompt cache. */
int64_t prompt_cache_tokens = 0;

/*! \brief The time of adding the request to engine. */
std::chrono::high_resolution_clock::time_point add_time_point;
@@ -149,6 +151,7 @@ struct RequestMetrics {
this->prompt_tokens = 0;
this->prefill_tokens = 0;
this->completion_tokens = 0;
this->prompt_cache_tokens = 0;
}
/*!
* \brief Return the request metrics in JSON.
@@ -182,6 +185,8 @@ struct EngineMetrics {
int64_t decode_tokens_sum = 0;
/*! \brief The total number of tokens predicted by jump-forward decoding. */
int64_t jump_forward_tokens_sum = 0;
/*! \brief The total number of tokens retrieved from prompt cache. */
int64_t prompt_cache_tokens_sum = 0;
/*! \brief metrics from last finished request. */
RequestMetrics last_finished_request;
/*! \brief speculative decoding metrics */
@@ -240,6 +245,7 @@ struct EngineMetrics {
completion_tokens_sum += request_metrics.completion_tokens;
decode_tokens_sum += request_metrics.decode_tokens;
jump_forward_tokens_sum += request_metrics.jump_forward_tokens;
prompt_cache_tokens_sum += request_metrics.prompt_cache_tokens;
last_finished_request = request_metrics;
}
/*!
Empty file added enable_debug_logging.py
Empty file.
11 changes: 11 additions & 0 deletions python/mlc_llm/interface/serve.py
@@ -16,6 +16,7 @@
)
from mlc_llm.serve.server import ServerContext
from mlc_llm.support import logging
import logging as std_logging

logger = logging.getLogger(__name__)

@@ -53,6 +54,16 @@ def serve(
allow_headers: Any,
): # pylint: disable=too-many-arguments, too-many-locals
"""Serve the model with the specified configuration."""

# Enable DEBUG logging if enable_debug is True
if enable_debug:
# Set logging level to DEBUG for detailed output including prompt context
std_logging.getLogger("mlc_llm").setLevel(std_logging.DEBUG)
std_logging.getLogger("mlc_llm.serve").setLevel(std_logging.DEBUG)
std_logging.getLogger("mlc_llm.serve.engine_base").setLevel(std_logging.DEBUG)
std_logging.getLogger("mlc_llm.serve.engine").setLevel(std_logging.DEBUG)
logger.info("DEBUG logging enabled for MLC-LLM serve mode")

# Create engine and start the background loop
async_engine = engine.AsyncMLCEngine(
model=model,
6 changes: 6 additions & 0 deletions python/mlc_llm/serve/engine.py
@@ -1223,6 +1223,10 @@ async def _handle_chat_completion(
)
# prompt length is not used
_ = prompt_length

# Log the final prompt for debugging
logger.debug("Request %s: Final prompt before processing: %s", request_id, prompts)

finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)]
self.state.record_event(request_id, event="invoke generate")
try:
@@ -1778,6 +1782,8 @@ def _handle_chat_completion(
self.max_input_sequence_length,
self.conv_template.model_copy(deep=True),
)
# prompt length is not used - this variable is kept for API compatibility
# and potential future use in the synchronous engine
_ = prompt_length

finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)]
32 changes: 27 additions & 5 deletions python/mlc_llm/serve/engine_base.py
@@ -752,17 +752,39 @@ def process_chat_completion_request( # pylint: disable=too-many-arguments
# - Get the prompt from template, and encode to token ids.
# - Check prompt length
engine_state.record_event(request_id, event="start tokenization")
prompts = engine_utils.process_prompts( # type: ignore
conv_template.as_prompt(model_config), f_tokenize
)

# Generate prompt string from chat template and log it for debugging
prompt_str = conv_template.as_prompt(model_config)
logger.debug("Request %s: Chat template applied, prompt string: %s", request_id, prompt_str)
logger.debug("Request %s: Chat template name: %s", request_id, conv_template.name)

# Log message details for better debugging
logger.debug("Request %s: Processing %d messages", request_id, len(request.messages))
for i, message in enumerate(request.messages):
role = message.role
content = str(message.content) if message.content else "(empty)"
content_preview = content[:100] + "..." if len(content) > 100 else content
logger.debug("Request %s: Message %d - Role: %s, Content: %s",
request_id, i, role, content_preview)

# Tokenize the prompt
prompts = engine_utils.process_prompts(prompt_str, f_tokenize) # type: ignore
logger.debug("Request %s: Tokenized prompt length: %d tokens",
request_id, len(prompts[0]) if prompts and isinstance(prompts[0], list) else 0)

engine_state.record_event(request_id, event="finish tokenization")

if conv_template.system_prefix_token_ids is not None:
logger.debug("Request %s: Adding system prefix token ids, count: %d",
request_id, len(conv_template.system_prefix_token_ids))
if isinstance(prompts[0], list):
prompts[0] = conv_template.system_prefix_token_ids + prompts[0]
else:
prompts.insert(0, conv_template.system_prefix_token_ids)

prompt_length = engine_utils.check_and_get_prompts_length(prompts, max_input_sequence_length)
logger.debug("Request %s: Final prompt length: %d tokens (max allowed: %d)",
request_id, prompt_length, max_input_sequence_length)

# Process generation config. Create request id.
generation_cfg = engine_utils.get_generation_config(
@@ -1252,8 +1274,8 @@ def wrap_completion_response( # pylint: disable=too-many-arguments
model: str,
output_texts: List[str],
finish_reasons: List[str],
logprob_results: List[Optional[openai_api_protocol.CompletionLogProbs]],
usage: openai_api_protocol.CompletionUsage,
logprob_results: Optional[List[Optional[openai_api_protocol.CompletionLogProbs]]] = None,
) -> openai_api_protocol.CompletionResponse:
"""Wrap the non-streaming completion results to CompletionResponse instance."""
return openai_api_protocol.CompletionResponse(
@@ -1263,7 +1285,7 @@ def wrap_completion_response( # pylint: disable=too-many-arguments
index=i,
finish_reason=finish_reason,
text=output_text,
logprobs=logprob_results[i],
logprobs=logprob_results[i] if logprob_results is not None else None,
)
for i, (output_text, finish_reason) in enumerate(zip(output_texts, finish_reasons))
],
15 changes: 15 additions & 0 deletions python/mlc_llm/serve/entrypoints/openai_entrypoints.py
@@ -17,6 +17,9 @@
)
from mlc_llm.serve import engine_base, engine_utils
from mlc_llm.serve.server import ServerContext
from mlc_llm.support import logging

logger = logging.getLogger(__name__)

app = fastapi.APIRouter()
################ v1/models ################
@@ -140,6 +143,18 @@ async def request_chat_completion(
"""OpenAI-compatible chat completion API.
API reference: https://platform.openai.com/docs/api-reference/chat
"""
# Log incoming request for debugging
logger.debug(
"Incoming chat completion request: model=%s, stream=%s, max_tokens=%s, messages=%d",
request.model, request.stream, request.max_tokens, len(request.messages))

# Log message details
for i, message in enumerate(request.messages):
content_preview = (str(message.content)[:100] + "..."
if message.content and len(str(message.content)) > 100
else str(message.content))
logger.debug("Request message %d: role=%s, content=%s", i, message.role, content_preview)

# - Check the requested model.
server_context: ServerContext = ServerContext.current()
request_final_usage_include_extra = server_context.enable_debug
68 changes: 68 additions & 0 deletions qat_training/README.md
@@ -0,0 +1,68 @@
# QAT Training for MLC-LLM

This directory contains scripts and utilities for Quantization-Aware Training (QAT) compatible with MLC-LLM's q4f16_1 quantization format.

## Overview

- **Base Model**: Llama3.2-1B (after SFT training)
- **Training Data**: ShareGPT format, distributed across multiple files
- **Target Quantization**: q4f16_1 format for MLC-LLM (a fake-quantization sketch follows this list)
- **Output**: QAT-trained model ready for MLC-LLM conversion
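
As an illustration of what training against this target involves, here is a minimal fake-quantization sketch. It is not the implementation in `training/qat_model.py`; the group size of 32, the symmetric [-7, 7] integer range, and the straight-through estimator are assumptions made for the example.

```python
# Minimal QAT fake-quantization sketch (illustrative, not the repo's qat_model.py).
# Assumed q4f16_1-style settings: 4-bit weights, group size 32, per-group scales.
import torch


def fake_quantize_q4(w: torch.Tensor, group_size: int = 32) -> torch.Tensor:
    """Simulate group-wise 4-bit quantization in the forward pass."""
    orig_shape = w.shape
    groups = w.reshape(-1, group_size)  # assumes the last dim is divisible by group_size
    # Per-group scale so the largest magnitude maps onto the 4-bit range [-7, 7].
    scale = groups.abs().amax(dim=1, keepdim=True).clamp(min=1e-6) / 7.0
    dequant = (torch.clamp(torch.round(groups / scale), -7, 7) * scale).reshape(orig_shape)
    # Straight-through estimator: quantized values in the forward pass,
    # unmodified gradients in the backward pass.
    return w + (dequant - w).detach()
```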

## Directory Structure

```
qat_training/
├── README.md # This file
├── config/
│ ├── training_config.py # Training configuration
│ └── model_config.py # Model-specific settings
├── data/
│ ├── data_loader.py # ShareGPT data loading utilities
│ ├── data_processor.py # Data preprocessing and sampling
│ └── data_sampler.py # Smart sampling from large datasets
├── training/
│ ├── qat_trainer.py # Main QAT training script
│ ├── qat_model.py # QAT model wrapper
│ └── metrics_logger.py # Training progress and metrics logging
├── conversion/
│ ├── weight_converter.py # Convert QAT weights to q4f16_1 format
│ └── mlc_formatter.py # Format weights for MLC-LLM
├── scripts/
│ ├── train_qat.py # Main training entry point
│ ├── convert_to_mlc.py # Conversion script
│ └── validate_model.py # Model validation
└── examples/
├── sample_config.yaml # Example configuration
└── run_training.sh # Example training script
```

## Quick Start

1. **Prepare your training data** (see the record-format sketch after this list):
```bash
python scripts/prepare_data.py --input_dir /path/to/sharegpt/files --output_dir ./data/processed --sample_count 30000
```

2. **Start QAT training**:
```bash
python scripts/train_qat.py --config examples/sample_config.yaml --model_path /path/to/your/sft/model
```

3. **Convert to MLC format**:
```bash
python scripts/convert_to_mlc.py --qat_model ./outputs/qat_trained --output_dir ./outputs/mlc_ready
```

4. **Use with MLC-LLM**:
```bash
mlc_llm convert_weight ./outputs/mlc_ready --quantization q4f16_1 --output ./final_model
```
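
The expected input layout is the common ShareGPT structure; the sketch below shows one record and a minimal multi-file loader. File naming and the exact `prepare_data.py` behavior are assumptions here; only the `conversations`/`from`/`value` layout follows the usual ShareGPT convention.

```python
# Minimal ShareGPT-style record and multi-file loader (illustrative only).
import json
from pathlib import Path
from typing import Iterator

EXAMPLE_RECORD = {
    "conversations": [
        {"from": "human", "value": "What is quantization-aware training?"},
        {"from": "gpt", "value": "QAT simulates low-precision arithmetic during training..."},
    ]
}


def iter_sharegpt(data_dir: str) -> Iterator[dict]:
    """Yield conversation records from every ShareGPT JSON file in a directory."""
    for path in sorted(Path(data_dir).glob("*.json")):
        with open(path, encoding="utf-8") as f:
            for record in json.load(f):
                if record.get("conversations"):
                    yield record
```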

## Features

- **Multi-file ShareGPT support**: Automatically loads and processes ShareGPT data from multiple files
- **Smart data sampling**: Intelligent sampling strategies to select representative data from large datasets (one possible strategy is sketched below)
- **Progress monitoring**: Comprehensive logging of training progress, loss, and metrics
- **MLC-LLM compatible**: Direct conversion to q4f16_1 format without external dependencies
- **Llama3.2-1B optimized**: Pre-configured for Llama3.2-1B architecture
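
As one example of what "smart sampling" could mean, the sketch below stratifies conversations by length so the selected subset roughly preserves the corpus's length distribution. The bucket size and proportional shares are assumptions, not the actual `data/data_sampler.py` behavior.

```python
# One possible length-stratified sampling strategy (illustrative only).
import random
from collections import defaultdict
from typing import Dict, List


def stratified_sample(conversations: List[dict], sample_count: int,
                      bucket_size: int = 512, seed: int = 0) -> List[dict]:
    rng = random.Random(seed)
    buckets: Dict[int, List[dict]] = defaultdict(list)
    for conv in conversations:
        total_len = sum(len(turn["value"]) for turn in conv["conversations"])
        buckets[total_len // bucket_size].append(conv)
    sampled: List[dict] = []
    for bucket in buckets.values():
        # Take a share of each bucket proportional to its size in the corpus.
        share = max(1, round(sample_count * len(bucket) / len(conversations)))
        sampled.extend(rng.sample(bucket, min(share, len(bucket))))
    rng.shuffle(sampled)
    return sampled[:sample_count]
```
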
1 change: 1 addition & 0 deletions qat_training/__init__.py
@@ -0,0 +1 @@
# QAT Training for MLC-LLM
1 change: 1 addition & 0 deletions qat_training/config/__init__.py
@@ -0,0 +1 @@
# Configuration modules