
[Bug] [TPU]: OOMing on Llama-8B on new vllm nightly docker #19490

Open
@BabyChouSr

Description

Your current environment

The output of python collect_env.py
root@t1v-n-82109f0e-w-0:/opt# python collect_env.py 
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
INFO 06-11 13:54:20 [__init__.py:244] Automatically detected platform tpu.
INFO 06-11 13:54:20 [tpu.py:215] tpu_commons not found, using vLLM's TpuPlatform
Collecting environment information...
==============================
        System Info
==============================
OS                           : Debian GNU/Linux 11 (bullseye) (x86_64)
GCC version                  : (Debian 10.2.1-6) 10.2.1 20210110
Clang version                : Could not collect
CMake version                : version 4.0.2
Libc version                 : glibc-2.31

==============================
       PyTorch Info
==============================
PyTorch version              : 2.8.0.dev20250605+cpu
Is debug build               : False
CUDA used to build PyTorch   : None
ROCM used to build PyTorch   : N/A

==============================
      Python Environment
==============================
Python version               : 3.10.16 (main, Jan 14 2025, 05:27:07) [GCC 10.2.1 20210110] (64-bit runtime)
Python platform              : Linux-6.8.0-1015-gcp-x86_64-with-glibc2.31

==============================
       CUDA / GPU Info
==============================
Is CUDA available            : False
CUDA runtime version         : No CUDA
CUDA_MODULE_LOADING set to   : N/A
GPU models and configuration : No CUDA
Nvidia driver version        : No CUDA
cuDNN version                : No CUDA
HIP runtime version          : N/A
MIOpen runtime version       : N/A
Is XNNPACK available         : True

==============================
          CPU Info
==============================
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Byte Order:                           Little Endian
Address sizes:                        52 bits physical, 57 bits virtual
CPU(s):                               180
On-line CPU(s) list:                  0-179
Thread(s) per core:                   1
Core(s) per socket:                   90
Socket(s):                            2
NUMA node(s):                         2
Vendor ID:                            AuthenticAMD
CPU family:                           25
Model:                                17
Model name:                           AMD EPYC 9B14
Stepping:                             1
CPU MHz:                              2599.998
BogoMIPS:                             5199.99
Hypervisor vendor:                    KVM
Virtualization type:                  full
L1d cache:                            5.6 MiB
L1i cache:                            5.6 MiB
L2 cache:                             180 MiB
L3 cache:                             704 MiB
NUMA node0 CPU(s):                    0-89
NUMA node1 CPU(s):                    90-179
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Not affected
Vulnerability Spec rstack overflow:   Mitigation; Safe RET
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP disabled; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Not affected
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext ssbd ibrs ibpb stibp vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512_bf16 clzero xsaveerptr wbnoinvd arat avx512vbmi umip avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid fsrm

==============================
Versions of relevant libraries
==============================
[pip3] mypy_extensions==1.1.0
[pip3] numpy==2.2.2
[pip3] pyzmq==26.4.0
[pip3] torch==2.8.0.dev20250605+cpu
[pip3] torch-xla==2.8.0+git402612d
[pip3] torchvision==0.23.0.dev20250605+cpu
[pip3] transformers==4.48.3
[pip3] triton==3.3.1
[conda] Could not collect

==============================
         vLLM Info
==============================
ROCM Version                 : Could not collect
Neuron SDK Version           : N/A
vLLM Version                 : 0.9.1rc2.dev5+g6cd4ae8ac (git sha: 6cd4ae8ac)
vLLM Build Flags:
  CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
  Could not collect

==============================
     Environment Variables
==============================
VLLM_TARGET_DEVICE=tpu
LD_LIBRARY_PATH=:/usr/lib/x86_64-linux-gnu/:/home/ray/anaconda3/lib
VLLM_USE_V1=1
NCCL_CUMEM_ENABLE=0
PYTORCH_NVML_BASED_CUDA_CHECK=1
TORCHINDUCTOR_COMPILE_THREADS=1
TORCHINDUCTOR_CACHE_DIR=/tmp/torchinductor_root

OOM when serving Llama-3.1-8B-Instruct (full server log below; a minimal offline reproduction sketch follows the final traceback):

root@t1v-n-82109f0e-w-0:/opt# vllm serve meta-llama/Llama-3.1-8B-Instruct --max-model-len 8192
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
INFO 06-11 13:55:34 [__init__.py:244] Automatically detected platform tpu.
INFO 06-11 13:55:34 [tpu.py:215] tpu_commons not found, using vLLM's TpuPlatform
INFO 06-11 13:55:37 [config.py:1980] Disabled the custom all-reduce kernel because it is not supported on current platform.
INFO 06-11 13:55:38 [config.py:1980] Disabled the custom all-reduce kernel because it is not supported on current platform.
INFO 06-11 13:55:38 [config.py:1980] Disabled the custom all-reduce kernel because it is not supported on current platform.
INFO 06-11 13:55:39 [config.py:1980] Disabled the custom all-reduce kernel because it is not supported on current platform.
INFO 06-11 13:55:39 [api_server.py:1287] vLLM API server version 0.9.1rc2.dev5+g6cd4ae8ac
INFO 06-11 13:55:39 [config.py:1980] Disabled the custom all-reduce kernel because it is not supported on current platform.
INFO 06-11 13:55:39 [cli_args.py:309] non-default args: {'model': 'meta-llama/Llama-3.1-8B-Instruct', 'max_model_len': 8192}
INFO 06-11 13:55:48 [config.py:823] This model supports multiple tasks: {'reward', 'embed', 'classify', 'generate', 'score'}. Defaulting to 'generate'.
INFO 06-11 13:55:48 [config.py:1980] Disabled the custom all-reduce kernel because it is not supported on current platform.
INFO 06-11 13:55:48 [config.py:2195] Chunked prefill is enabled with max_num_batched_tokens=2048.
INFO 06-11 13:55:48 [tpu.py:105] [TPU] Forcing DYNAMO_ONCE compilation level
WARNING 06-11 13:55:50 [env_override.py:17] NCCL_CUMEM_ENABLE is set to 0, skipping override. This may increase memory overhead with cudagraph+allreduce: https://github.com/NVIDIA/nccl/issues/1234
INFO 06-11 13:55:54 [__init__.py:244] Automatically detected platform tpu.
INFO 06-11 13:55:54 [tpu.py:215] tpu_commons not found, using vLLM's TpuPlatform
INFO 06-11 13:55:56 [core.py:455] Waiting for init message from front-end.
INFO 06-11 13:55:56 [tpu.py:105] [TPU] Forcing DYNAMO_ONCE compilation level
INFO 06-11 13:55:56 [core.py:70] Initializing a V1 LLM engine (v0.9.1rc2.dev5+g6cd4ae8ac) with config: model='meta-llama/Llama-3.1-8B-Instruct', speculative_config=None, tokenizer='meta-llama/Llama-3.1-8B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=8192, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=True, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=None, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=meta-llama/Llama-3.1-8B-Instruct, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=False, pooler_config=None, compilation_config={"level":2,"debug_dump_path":"","cache_dir":"","backend":"openxla","custom_ops":["none"],"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output"],"use_inductor":true,"compile_sizes":[],"inductor_compile_config":{"enable_auto_functionalized_v2":false},"inductor_passes":{},"use_cudagraph":true,"cudagraph_num_of_warmups":1,"cudagraph_capture_sizes":[512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"cudagraph_copy_inputs":false,"full_cuda_graph":false,"max_capture_size":512,"local_cache_dir":null}
INFO 06-11 13:55:56 [tpu_worker.py:298] tpu_commons not found, using vLLM's TPUWorker.
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
INFO 06-11 13:55:56 [parallel_state.py:1065] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
WARNING 06-11 13:56:29 [tpu.py:178] Pin memory is not supported on TPU.
INFO 06-11 13:56:29 [tpu_model_runner.py:1617] Using exponential token paddings:
INFO 06-11 13:56:29 [tpu_model_runner.py:1619]     16
INFO 06-11 13:56:29 [tpu_model_runner.py:1619]     32
INFO 06-11 13:56:29 [tpu_model_runner.py:1619]     64
INFO 06-11 13:56:29 [tpu_model_runner.py:1619]     128
INFO 06-11 13:56:29 [tpu_model_runner.py:1619]     256
INFO 06-11 13:56:29 [tpu_model_runner.py:1619]     512
INFO 06-11 13:56:29 [tpu_model_runner.py:1619]     1024
INFO 06-11 13:56:29 [tpu_model_runner.py:1619]     2048
INFO 06-11 13:56:29 [tpu_model_runner.py:1583] Preparing request paddings:
INFO 06-11 13:56:29 [tpu_model_runner.py:1590]     8
INFO 06-11 13:56:29 [tpu_model_runner.py:1590]     16
INFO 06-11 13:56:29 [tpu_model_runner.py:1590]     32
INFO 06-11 13:56:29 [tpu_model_runner.py:1590]     64
INFO 06-11 13:56:29 [tpu_model_runner.py:1590]     128
INFO 06-11 13:56:29 [tpu_model_runner.py:1590]     256
INFO 06-11 13:56:29 [tpu_model_runner.py:997] Loading model from scratch...
INFO 06-11 13:56:29 [tpu.py:51] Cannot use None backend on TPU.
INFO 06-11 13:56:29 [tpu.py:54] Using Pallas V1 backend.
INFO 06-11 13:56:30 [weight_utils.py:292] Using model weights format ['*.safetensors']
INFO 06-11 13:56:30 [weight_utils.py:308] Time spent downloading weights for meta-llama/Llama-3.1-8B-Instruct: 0.515513 seconds
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:02<00:06,  2.00s/it]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:02<00:01,  1.01it/s]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:02<00:00,  1.49it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00,  2.07it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00,  1.45it/s]

INFO 06-11 13:56:33 [default_loader.py:272] Loading weights took 2.77 seconds
INFO 06-11 13:56:48 [tpu_model_runner.py:1439] Clear dynamo cache and cached dynamo bytecode.
INFO 06-11 13:56:48 [kv_cache_utils.py:715] GPU KV cache size: 104,960 tokens
INFO 06-11 13:56:48 [kv_cache_utils.py:719] Maximum concurrency for 8,192 tokens per request: 12.81x
INFO 06-11 13:56:50 [tpu_model_runner.py:1149] Compiling the model with different input shapes.
INFO 06-11 13:56:50 [tpu_model_runner.py:1152]   -- num_tokens: 16
/usr/local/lib/python3.10/site-packages/jax/_src/cloud_tpu_init.py:82: UserWarning: Transparent hugepages are not enabled. TPU runtime startup and shutdown time should be significantly improved on TPU v5e and newer. If not already set, you may need to enable transparent hugepages in your VM image (sudo sh -c "echo always > /sys/kernel/mm/transparent_hugepage/enabled")
  warnings.warn(
ERROR 06-11 13:57:14 [core.py:515] EngineCore failed to start.
ERROR 06-11 13:57:14 [core.py:515] Traceback (most recent call last):
ERROR 06-11 13:57:14 [core.py:515]   File "/workspace/vllm/vllm/v1/engine/core.py", line 506, in run_engine_core
ERROR 06-11 13:57:14 [core.py:515]     engine_core = EngineCoreProc(*args, **kwargs)
ERROR 06-11 13:57:14 [core.py:515]   File "/workspace/vllm/vllm/v1/engine/core.py", line 390, in __init__
ERROR 06-11 13:57:14 [core.py:515]     super().__init__(vllm_config, executor_class, log_stats,
ERROR 06-11 13:57:14 [core.py:515]   File "/workspace/vllm/vllm/v1/engine/core.py", line 83, in __init__
ERROR 06-11 13:57:14 [core.py:515]     self._initialize_kv_caches(vllm_config)
ERROR 06-11 13:57:14 [core.py:515]   File "/workspace/vllm/vllm/v1/engine/core.py", line 168, in _initialize_kv_caches
ERROR 06-11 13:57:14 [core.py:515]     self.model_executor.initialize_from_config(kv_cache_configs)
ERROR 06-11 13:57:14 [core.py:515]   File "/workspace/vllm/vllm/v1/executor/abstract.py", line 66, in initialize_from_config
ERROR 06-11 13:57:14 [core.py:515]     self.collective_rpc("compile_or_warm_up_model")
ERROR 06-11 13:57:14 [core.py:515]   File "/workspace/vllm/vllm/executor/uniproc_executor.py", line 57, in collective_rpc
ERROR 06-11 13:57:14 [core.py:515]     answer = run_method(self.driver_worker, method, args, kwargs)
ERROR 06-11 13:57:14 [core.py:515]   File "/workspace/vllm/vllm/utils.py", line 2671, in run_method
ERROR 06-11 13:57:14 [core.py:515]     return func(*args, **kwargs)
ERROR 06-11 13:57:14 [core.py:515]   File "/workspace/vllm/vllm/v1/worker/tpu_worker.py", line 248, in compile_or_warm_up_model
ERROR 06-11 13:57:14 [core.py:515]     self.model_runner.capture_model()
ERROR 06-11 13:57:14 [core.py:515]   File "/workspace/vllm/vllm/v1/worker/tpu_model_runner.py", line 1283, in capture_model
ERROR 06-11 13:57:14 [core.py:515]     self._precompile_backbone()
ERROR 06-11 13:57:14 [core.py:515]   File "/workspace/vllm/vllm/v1/worker/tpu_model_runner.py", line 1153, in _precompile_backbone
ERROR 06-11 13:57:14 [core.py:515]     self._dummy_run(num_tokens)
ERROR 06-11 13:57:14 [core.py:515]   File "/usr/local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 06-11 13:57:14 [core.py:515]     return func(*args, **kwargs)
ERROR 06-11 13:57:14 [core.py:515]   File "/workspace/vllm/vllm/v1/worker/tpu_model_runner.py", line 1074, in _dummy_run
ERROR 06-11 13:57:14 [core.py:515]     out = self.model(input_ids=input_ids,
ERROR 06-11 13:57:14 [core.py:515]   File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl
ERROR 06-11 13:57:14 [core.py:515]     return self._call_impl(*args, **kwargs)
ERROR 06-11 13:57:14 [core.py:515]   File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1778, in _call_impl
ERROR 06-11 13:57:14 [core.py:515]     return forward_call(*args, **kwargs)
ERROR 06-11 13:57:14 [core.py:515]   File "/workspace/vllm/vllm/model_executor/models/llama.py", line 581, in forward
ERROR 06-11 13:57:14 [core.py:515]     model_output = self.model(input_ids, positions, intermediate_tensors,
ERROR 06-11 13:57:14 [core.py:515]   File "/workspace/vllm/vllm/compilation/decorators.py", line 239, in __call__
ERROR 06-11 13:57:14 [core.py:515]     output = self.compiled_callable(*args, **kwargs)
ERROR 06-11 13:57:14 [core.py:515]   File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 699, in compile_wrapper
ERROR 06-11 13:57:14 [core.py:515]     return fn(*args, **kwargs)
ERROR 06-11 13:57:14 [core.py:515]   File "/workspace/vllm/vllm/model_executor/models/llama.py", line 368, in forward
ERROR 06-11 13:57:14 [core.py:515]     def forward(
ERROR 06-11 13:57:14 [core.py:515]   File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 893, in _fn
ERROR 06-11 13:57:14 [core.py:515]     return fn(*args, **kwargs)
ERROR 06-11 13:57:14 [core.py:515]   File "/usr/local/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1231, in forward
ERROR 06-11 13:57:14 [core.py:515]     return compiled_fn(full_args)
ERROR 06-11 13:57:14 [core.py:515]   File "/usr/local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 356, in runtime_wrapper
ERROR 06-11 13:57:14 [core.py:515]     all_outs = call_func_at_runtime_with_args(
ERROR 06-11 13:57:14 [core.py:515]   File "/usr/local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 126, in call_func_at_runtime_with_args
ERROR 06-11 13:57:14 [core.py:515]     out = normalize_as_list(f(args))
ERROR 06-11 13:57:14 [core.py:515]   File "/usr/local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 722, in inner_fn
ERROR 06-11 13:57:14 [core.py:515]     outs = compiled_fn(args)
ERROR 06-11 13:57:14 [core.py:515]   File "/usr/local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 528, in wrapper
ERROR 06-11 13:57:14 [core.py:515]     return compiled_fn(runtime_args)
ERROR 06-11 13:57:14 [core.py:515]   File "/usr/local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 100, in g
ERROR 06-11 13:57:14 [core.py:515]     return f(*args)
ERROR 06-11 13:57:14 [core.py:515]   File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/backends/torchxla.py", line 37, in fwd
ERROR 06-11 13:57:14 [core.py:515]     compiled_graph = bridge.extract_compiled_graph(model, args)
ERROR 06-11 13:57:14 [core.py:515]   File "/usr/local/lib/python3.10/site-packages/torch_xla/_dynamo/dynamo_bridge.py", line 733, in extract_compiled_graph
ERROR 06-11 13:57:14 [core.py:515]     return extract_compiled_graph_helper(xla_model, xla_args)
ERROR 06-11 13:57:14 [core.py:515]   File "/usr/local/lib/python3.10/site-packages/torch_xla/_dynamo/dynamo_bridge.py", line 853, in extract_compiled_graph_helper
ERROR 06-11 13:57:14 [core.py:515]     return extract_internal(xla_model)
ERROR 06-11 13:57:14 [core.py:515]   File "/usr/local/lib/python3.10/site-packages/torch_xla/_dynamo/dynamo_bridge.py", line 539, in extract_internal
ERROR 06-11 13:57:14 [core.py:515]     xla_args_need_update) = extract_graph_helper(xla_model,
ERROR 06-11 13:57:14 [core.py:515]   File "/usr/local/lib/python3.10/site-packages/torch_xla/_dynamo/dynamo_bridge.py", line 483, in extract_graph_helper
ERROR 06-11 13:57:14 [core.py:515]     torch_xla._XLAC._xla_warm_up_cache(args_and_out_tensor_only, [])
ERROR 06-11 13:57:14 [core.py:515] ValueError: Ran out of memory in memory space vmem while allocating on stack for %name.32 = bf16[16,32,128]{2,1,0:T(8,128)(2,1)S(1)} custom-call(%p16.165, %copy, %p14.163, %p13.162, %copy.289, /*index=5*/%pad_maximum_fusion, %bitcast.1), custom_call_target="tpu_custom_call", operand_layout_constraints={s32[256]{0}, s32[256,32]{1,0}, s32[257]{0}, s32[2]{0}, s32[1]{0}, bf16[16,32,128]{2,1,0}, bf16[410,256,16,128]{3,2,1,0}}. Scoped allocation with size 64.82M and limit 64.00M exceeded scoped vmem limit by 840.0K. It should not be possible to run out of scoped vmem - please file a bug against XLA.
Process EngineCore_0:
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/local/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/workspace/vllm/vllm/v1/engine/core.py", line 519, in run_engine_core
    raise e
  File "/workspace/vllm/vllm/v1/engine/core.py", line 506, in run_engine_core
    engine_core = EngineCoreProc(*args, **kwargs)
  File "/workspace/vllm/vllm/v1/engine/core.py", line 390, in __init__
    super().__init__(vllm_config, executor_class, log_stats,
  File "/workspace/vllm/vllm/v1/engine/core.py", line 83, in __init__
    self._initialize_kv_caches(vllm_config)
  File "/workspace/vllm/vllm/v1/engine/core.py", line 168, in _initialize_kv_caches
    self.model_executor.initialize_from_config(kv_cache_configs)
  File "/workspace/vllm/vllm/v1/executor/abstract.py", line 66, in initialize_from_config
    self.collective_rpc("compile_or_warm_up_model")
  File "/workspace/vllm/vllm/executor/uniproc_executor.py", line 57, in collective_rpc
    answer = run_method(self.driver_worker, method, args, kwargs)
  File "/workspace/vllm/vllm/utils.py", line 2671, in run_method
    return func(*args, **kwargs)
  File "/workspace/vllm/vllm/v1/worker/tpu_worker.py", line 248, in compile_or_warm_up_model
    self.model_runner.capture_model()
  File "/workspace/vllm/vllm/v1/worker/tpu_model_runner.py", line 1283, in capture_model
    self._precompile_backbone()
  File "/workspace/vllm/vllm/v1/worker/tpu_model_runner.py", line 1153, in _precompile_backbone
    self._dummy_run(num_tokens)
  File "/usr/local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/vllm/vllm/v1/worker/tpu_model_runner.py", line 1074, in _dummy_run
    out = self.model(input_ids=input_ids,
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1778, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/vllm/vllm/model_executor/models/llama.py", line 581, in forward
    model_output = self.model(input_ids, positions, intermediate_tensors,
  File "/workspace/vllm/vllm/compilation/decorators.py", line 239, in __call__
    output = self.compiled_callable(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 699, in compile_wrapper
    return fn(*args, **kwargs)
  File "/workspace/vllm/vllm/model_executor/models/llama.py", line 368, in forward
    def forward(
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 893, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1231, in forward
    return compiled_fn(full_args)
  File "/usr/local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 356, in runtime_wrapper
    all_outs = call_func_at_runtime_with_args(
  File "/usr/local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 126, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
  File "/usr/local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 722, in inner_fn
    outs = compiled_fn(args)
  File "/usr/local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 528, in wrapper
    return compiled_fn(runtime_args)
  File "/usr/local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 100, in g
    return f(*args)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/backends/torchxla.py", line 37, in fwd
    compiled_graph = bridge.extract_compiled_graph(model, args)
  File "/usr/local/lib/python3.10/site-packages/torch_xla/_dynamo/dynamo_bridge.py", line 733, in extract_compiled_graph
    return extract_compiled_graph_helper(xla_model, xla_args)
  File "/usr/local/lib/python3.10/site-packages/torch_xla/_dynamo/dynamo_bridge.py", line 853, in extract_compiled_graph_helper
    return extract_internal(xla_model)
  File "/usr/local/lib/python3.10/site-packages/torch_xla/_dynamo/dynamo_bridge.py", line 539, in extract_internal
    xla_args_need_update) = extract_graph_helper(xla_model,
  File "/usr/local/lib/python3.10/site-packages/torch_xla/_dynamo/dynamo_bridge.py", line 483, in extract_graph_helper
    torch_xla._XLAC._xla_warm_up_cache(args_and_out_tensor_only, [])
ValueError: Ran out of memory in memory space vmem while allocating on stack for %name.32 = bf16[16,32,128]{2,1,0:T(8,128)(2,1)S(1)} custom-call(%p16.165, %copy, %p14.163, %p13.162, %copy.289, /*index=5*/%pad_maximum_fusion, %bitcast.1), custom_call_target="tpu_custom_call", operand_layout_constraints={s32[256]{0}, s32[256,32]{1,0}, s32[257]{0}, s32[2]{0}, s32[1]{0}, bf16[16,32,128]{2,1,0}, bf16[410,256,16,128]{3,2,1,0}}. Scoped allocation with size 64.82M and limit 64.00M exceeded scoped vmem limit by 840.0K. It should not be possible to run out of scoped vmem - please file a bug against XLA.
Traceback (most recent call last):
  File "/usr/local/bin/vllm", line 8, in <module>
    sys.exit(main())
  File "/workspace/vllm/vllm/entrypoints/cli/main.py", line 59, in main
    args.dispatch_function(args)
  File "/workspace/vllm/vllm/entrypoints/cli/serve.py", line 58, in cmd
    uvloop.run(run_server(args))
  File "/usr/local/lib/python3.10/site-packages/uvloop/__init__.py", line 82, in run
    return loop.run_until_complete(wrapper())
  File "uvloop/loop.pyx", line 1518, in uvloop.loop.Loop.run_until_complete
  File "/usr/local/lib/python3.10/site-packages/uvloop/__init__.py", line 61, in wrapper
    return await main
  File "/workspace/vllm/vllm/entrypoints/openai/api_server.py", line 1323, in run_server
    await run_server_worker(listen_address, sock, args, **uvicorn_kwargs)
  File "/workspace/vllm/vllm/entrypoints/openai/api_server.py", line 1343, in run_server_worker
    async with build_async_engine_client(args, client_config) as engine_client:
  File "/usr/local/lib/python3.10/contextlib.py", line 199, in __aenter__
    return await anext(self.gen)
  File "/workspace/vllm/vllm/entrypoints/openai/api_server.py", line 155, in build_async_engine_client
    async with build_async_engine_client_from_engine_args(
  File "/usr/local/lib/python3.10/contextlib.py", line 199, in __aenter__
    return await anext(self.gen)
  File "/workspace/vllm/vllm/entrypoints/openai/api_server.py", line 191, in build_async_engine_client_from_engine_args
    async_llm = AsyncLLM.from_vllm_config(
  File "/workspace/vllm/vllm/v1/engine/async_llm.py", line 162, in from_vllm_config
    return cls(
  File "/workspace/vllm/vllm/v1/engine/async_llm.py", line 124, in __init__
    self.engine_core = EngineCoreClient.make_async_mp_client(
  File "/workspace/vllm/vllm/v1/engine/core_client.py", line 93, in make_async_mp_client
    return AsyncMPClient(vllm_config, executor_class, log_stats,
  File "/workspace/vllm/vllm/v1/engine/core_client.py", line 716, in __init__
    super().__init__(
  File "/workspace/vllm/vllm/v1/engine/core_client.py", line 422, in __init__
    self._init_engines_direct(vllm_config, local_only,
  File "/workspace/vllm/vllm/v1/engine/core_client.py", line 491, in _init_engines_direct
    self._wait_for_engine_startup(handshake_socket, input_address,
  File "/workspace/vllm/vllm/v1/engine/core_client.py", line 511, in _wait_for_engine_startup
    wait_for_engine_startup(
  File "/workspace/vllm/vllm/v1/utils.py", line 494, in wait_for_engine_startup
    raise RuntimeError("Engine core initialization failed. "
RuntimeError: Engine core initialization failed. See root cause above. Failed core proc(s): {}
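
A minimal offline reproduction sketch (same model and max_model_len, assuming the same nightly TPU image). The LIBTPU_INIT_ARGS line is only a hypothetical diagnostic for the scoped-VMEM error above, not a verified workaround, and whether --xla_tpu_scoped_vmem_limit_kib is honored by this libtpu build is an assumption:

import os

# Hypothetical diagnostic only: raise XLA's scoped VMEM limit (value in KiB) above the
# ~64.82 MiB the Pallas attention kernel requests; the default limit is 64 MiB.
# Must be set before libtpu initializes. Comment this out to reproduce the failure as-is.
os.environ.setdefault("LIBTPU_INIT_ARGS", "--xla_tpu_scoped_vmem_limit_kib=98304")

from vllm import LLM, SamplingParams

# Same configuration as the vllm serve command above; constructing the engine already
# runs the TPU warm-up/precompile step where the OOM is raised.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", max_model_len=8192)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)

The same diagnostic can be tried with the server by exporting LIBTPU_INIT_ARGS before running vllm serve.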

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.


Labels: bug (Something isn't working), tpu (Related to Google TPUs)
