diff --git a/python/mlc_llm/cli/model_metadata.py b/python/mlc_llm/cli/model_metadata.py
index c83c39f36e..9d71d124a8 100644
--- a/python/mlc_llm/cli/model_metadata.py
+++ b/python/mlc_llm/cli/model_metadata.py
@@ -93,12 +93,49 @@ def _report_memory_usage(metadata: Dict[str, Any], config: Union[Dict, ConfigBas
     total_size = params_bytes + temp_func_bytes
     logger.info(
         "%s: %.2f MB (Parameters: %.2f MB. Temporary buffer: %.2f MB)",
-        green("Total memory usage without KV cache:"),
+        green("Total memory usage without KV cache"),
         total_size / 1024 / 1024,
         params_bytes / 1024 / 1024,
         temp_func_bytes / 1024 / 1024,
     )
+    # Compute KV cache size per token of context window.
+    if isinstance(config, ConfigBase):
+        config = asdict(config)
+    if (
+        "head_dim" in config
+        and "num_hidden_layers" in config
+        and "num_key_value_heads" in config
+        and "quantization" in metadata
+    ):
+        quantization_type = metadata["quantization"]
+        dtype_bytes = None
+        if "f32" in quantization_type:
+            dtype_bytes = 4
+        elif "bf16" in quantization_type:
+            dtype_bytes = 2
+        elif "f16" in quantization_type:
+            dtype_bytes = 2
+        # TODO: If quantized KV cache is supported in the future, this needs to change  # pylint: disable=fixme
+        if dtype_bytes is not None:
+            bytes_per_token = (
+                config["head_dim"]
+                * config["num_hidden_layers"]
+                * config["num_key_value_heads"]
+                * dtype_bytes
+                * 2  # 2 for key and value
+            )
+            logger.info(
+                "%s: %.2f MB per token in the context window",
+                green("KV cache size"),
+                bytes_per_token / 1024 / 1024,
+            )
+            logger.info(
+                "%s: %.2f MB",
+                green("Total memory usage with a 4K KV cache"),
+                (total_size + bytes_per_token * 4096) / 1024 / 1024,
+            )
+
     logger.info(
         "To reduce memory usage, "
         "tweak `prefill_chunk_size`, `context_window_size` and `sliding_window_size`"