Skip to content

Commit 7f3422b

Browse files
author
Flip
committed
Fix: Account for CPU offloading in KV cache memory check
The check_enough_kv_cache_memory() function was not accounting for CPU offloading capacity when validating available memory. This caused the V1 engine to fail with a 'No available memory for cache blocks' error even when --cpu-offload-gb was set. This fix adds the CPU offload capacity to the effective available memory before performing the check, allowing 7B–13B models to work correctly with CPU offloading on 12 GB GPUs. Fixes #27934
1 parent 8d259fa commit 7f3422b

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

vllm/v1/core/kv_cache_utils.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -682,7 +682,12 @@ def check_enough_kv_cache_memory(
682682
if not kv_cache_spec:
683683
return
684684

685-
if available_memory <= 0:
685+
# Account for CPU offloading when checking memory availability
686+
# When CPU offloading is enabled, effective memory is GPU + CPU offload
687+
cpu_offload_bytes = int(vllm_config.cache_config.cpu_offload_gb * GiB_bytes)
688+
effective_available_memory = available_memory + cpu_offload_bytes
689+
690+
if effective_available_memory <= 0:
686691
raise ValueError(
687692
"No available memory for the cache blocks. "
688693
"Try increasing `gpu_memory_utilization` when "
@@ -692,25 +697,32 @@ def check_enough_kv_cache_memory(
692697
max_model_len = vllm_config.model_config.max_model_len
693698
needed_memory = max_memory_usage_bytes(vllm_config, kv_cache_spec.values())
694699

695-
if needed_memory > available_memory:
700+
if needed_memory > effective_available_memory:
696701
# Estimate the maximum model length that can fit in the available memory
697702
estimated_max_len = estimate_max_model_len(
698-
vllm_config, kv_cache_spec, available_memory
703+
vllm_config, kv_cache_spec, effective_available_memory
699704
)
700705
estimated_msg = ""
701706
if estimated_max_len > 0:
702707
estimated_msg = (
703-
"Based on the available memory, "
708+
"Based on the available memory (GPU + CPU offload), "
704709
f"the estimated maximum model length is {estimated_max_len}."
705710
)
706711

712+
offload_info = ""
713+
if cpu_offload_bytes > 0:
714+
offload_info = (
715+
f" (GPU: {available_memory / GiB_bytes:.2f} GiB + "
716+
f"CPU offload: {cpu_offload_bytes / GiB_bytes:.2f} GiB)"
717+
)
718+
707719
raise ValueError(
708720
f"To serve at least one request with the model's max seq len "
709721
f"({max_model_len}), ({needed_memory / GiB_bytes:.2f} GiB KV "
710722
f"cache is needed, which is larger than the available KV cache "
711-
f"memory ({available_memory / GiB_bytes:.2f} GiB). "
723+
f"memory ({effective_available_memory / GiB_bytes:.2f} GiB{offload_info}). "
712724
f"{estimated_msg} "
713-
f"Try increasing `gpu_memory_utilization` or decreasing "
725+
f"Try increasing `gpu_memory_utilization`, `cpu_offload_gb`, or decreasing "
714726
f"`max_model_len` when initializing the engine."
715727
)
716728

0 commit comments

Comments
 (0)