From c709b6f0981db981271a488c15b846afd795258f Mon Sep 17 00:00:00 2001 From: Charlie Ruan <53290280+CharlieFRuan@users.noreply.github.com> Date: Sat, 3 May 2025 21:37:35 -0400 Subject: [PATCH 1/2] [CLI] Report KV cache memory usage in mlc_llm compile --- python/mlc_llm/cli/model_metadata.py | 36 ++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/python/mlc_llm/cli/model_metadata.py b/python/mlc_llm/cli/model_metadata.py index c83c39f36e..8def1404a5 100644 --- a/python/mlc_llm/cli/model_metadata.py +++ b/python/mlc_llm/cli/model_metadata.py @@ -99,6 +99,42 @@ def _report_memory_usage(metadata: Dict[str, Any], config: Union[Dict, ConfigBas temp_func_bytes / 1024 / 1024, ) + # Compute KV cache size per token of context window. + if isinstance(config, ConfigBase): + config = asdict(config) + if ( + "head_dim" in config + and "num_hidden_layers" in config + and "num_key_value_heads" in config + and "quantization" in metadata + ): + quantization_type = metadata["quantization"] + dtype_bytes = None + if "f32" in quantization_type: + dtype_bytes = 4 + elif "bf16" in quantization_type: + dtype_bytes = 2 + elif "f16" in quantization_type: + dtype_bytes = 2 + # TODO: In future, if support quantized KV cache, need to change this calculation + if dtype_bytes is not None: + bytes_per_token = ( + config["head_dim"] + * config["num_hidden_layers"] + * config["num_key_value_heads"] + * dtype_bytes + * 2 # 2 for key and value + ) + logger.info( + "%s: %.2f MB per token in the context window", + green("KV cache size:"), + bytes_per_token / 1024 / 1024, + ) + logger.info( + "Total memory usage with a 4K KV cache: %.2f MB", + (total_size + bytes_per_token * 4096) / 1024 / 1024, + ) + logger.info( "To reduce memory usage, " "tweak `prefill_chunk_size`, `context_window_size` and `sliding_window_size`" From 475854980e362f255685e322345a6092128cf433 Mon Sep 17 00:00:00 2001 From: Charlie Ruan <53290280+CharlieFRuan@users.noreply.github.com> Date: Sat, 3 May 2025 21:45:07 -0400 Subject: [PATCH 2/2] Trivial --- python/mlc_llm/cli/model_metadata.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/python/mlc_llm/cli/model_metadata.py b/python/mlc_llm/cli/model_metadata.py index 8def1404a5..9d71d124a8 100644 --- a/python/mlc_llm/cli/model_metadata.py +++ b/python/mlc_llm/cli/model_metadata.py @@ -93,7 +93,7 @@ def _report_memory_usage(metadata: Dict[str, Any], config: Union[Dict, ConfigBas total_size = params_bytes + temp_func_bytes logger.info( "%s: %.2f MB (Parameters: %.2f MB. Temporary buffer: %.2f MB)", - green("Total memory usage without KV cache:"), + green("Total memory usage without KV cache"), total_size / 1024 / 1024, params_bytes / 1024 / 1024, temp_func_bytes / 1024 / 1024, @@ -116,7 +116,7 @@ def _report_memory_usage(metadata: Dict[str, Any], config: Union[Dict, ConfigBas dtype_bytes = 2 elif "f16" in quantization_type: dtype_bytes = 2 - # TODO: In future, if support quantized KV cache, need to change this calculation + # TODO: If support quantized KV in future, need to change this # pylint: disable=fixme if dtype_bytes is not None: bytes_per_token = ( config["head_dim"] @@ -127,11 +127,12 @@ def _report_memory_usage(metadata: Dict[str, Any], config: Union[Dict, ConfigBas ) logger.info( "%s: %.2f MB per token in the context window", - green("KV cache size:"), + green("KV cache size"), bytes_per_token / 1024 / 1024, ) logger.info( - "Total memory usage with a 4K KV cache: %.2f MB", + "%s: %.2f MB", + green("Total memory usage with a 4K KV cache"), (total_size + bytes_per_token * 4096) / 1024 / 1024, )