class SimpleCPUOffloadScheduler:
"""Scheduler-side manager for CPU offloading."""
def __init__(
self,
vllm_config: VllmConfig,
kv_cache_config: "KVCacheConfig | None",
cpu_capacity_bytes: int,
lazy_offload: bool = False,
):
self.vllm_config = vllm_config
self.kv_cache_config = kv_cache_config
self.enable_kv_cache_events = (
vllm_config.kv_events_config is not None
and vllm_config.kv_events_config.enable_kv_cache_events
)
# NOTE: We use the same block size for both GPU and CPU.
self.block_size = vllm_config.cache_config.block_size
# Derive a CPU KVCacheConfig from the GPU config and build a coordinator
assert kv_cache_config is not None
self.cpu_kv_cache_config = self._derive_cpu_config(
kv_cache_config, cpu_capacity_bytes
)
self.num_cpu_blocks = self.cpu_kv_cache_config.num_blocks
# Find the full attention kv group for prefix cache matching.
self.fa_gidx = -1
for g_idx, g in enumerate(self.cpu_kv_cache_config.kv_cache_groups):
if isinstance(g.kv_cache_spec, FullAttentionSpec):
self.fa_gidx = g_idx
break
assert 0 <= self.fa_gidx < len(self.cpu_kv_cache_config.kv_cache_groups)
logger.info(
"SimpleCPUOffloadScheduler: Allocating %d CPU blocks (%.2f GB, mode=%s)",
self.num_cpu_blocks,
cpu_capacity_bytes / (1024**3),
"lazy" if lazy_offload else "eager",
)
# TODO (yifan): maybe need to enable kv_cache_events and metrics_collector here.
dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
pcp_world_size = vllm_config.parallel_config.prefill_context_parallel_size
assert dcp_world_size == 1 and pcp_world_size == 1
self.cpu_coordinator: KVCacheCoordinator = get_kv_cache_coordinator(
kv_cache_config=self.cpu_kv_cache_config,
max_model_len=vllm_config.model_config.max_model_len,
use_eagle=False,
enable_caching=True,
enable_kv_cache_events=self.enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
hash_block_size=self.block_size,
)
self.cpu_block_pool: BlockPool = self.cpu_coordinator.block_pool
# GPU block pool reference - bound after scheduler builds kv_cache_manager
self._gpu_block_pool: BlockPool | None = None
# Load metadata
self._reqs_to_load: dict[str, LoadRequestState] = {}
# Inverse map: load_event_idx -> req_ids. Keyed by load_event_idx because
# the worker reports completions by event index, not request id.
self._load_event_to_reqs: dict[int, list[str]] = {}
# Store metadata
self._lazy_mode = lazy_offload
# Lazy mode: use a cursor to track the last scanned block in the GPU free queue.
self._cursor: KVCacheBlock | None = None
if self._lazy_mode:
self._target_free = self._estimate_lazy_target_blocks(
kv_cache_config,
vllm_config.scheduler_config.max_num_batched_tokens,
)
else:
self._target_free = 0
self._store_event_to_blocks: dict[int, TransferMeta] = {}
# Eager mode only
self._reqs_to_store: dict[str, StoreRequestState] = {}
self._store_event_to_reqs: dict[int, list[str]] = {}
# Event counters
self._load_event_counter: int = 0
self._store_event_counter: int = 0
# For TP/PP: track partial store completions across steps.
# Events must be reported by all world_size workers before considered complete.
self._expected_worker_count = vllm_config.parallel_config.world_size
self._store_event_pending_counts: dict[int, int] = {}
@staticmethod
def _derive_cpu_config(
gpu_config: "KVCacheConfig", cpu_capacity_bytes: int
) -> "KVCacheConfig":
"""Derive a CPU KVCacheConfig from the GPU config.
Same kv_cache_groups, num_blocks scaled by CPU/GPU memory ratio."""
# Import here to avoid potential circular imports
from vllm.v1.kv_cache_interface import KVCacheConfig as KVCacheConfigCls
from vllm.v1.kv_cache_interface import KVCacheTensor
assert len(gpu_config.kv_cache_tensors) > 0
gpu_total_bytes = sum(t.size for t in gpu_config.kv_cache_tensors)
num_gpu_blocks = gpu_config.num_blocks
num_cpu_blocks = max(1, num_gpu_blocks * cpu_capacity_bytes // gpu_total_bytes)
# Create CPU kv_cache_tensors mirroring GPU by scaling size proportionally.
cpu_tensors = [
KVCacheTensor(
size=t.size // num_gpu_blocks * num_cpu_blocks,
shared_by=list(t.shared_by),
)
for t in gpu_config.kv_cache_tensors
]
return KVCacheConfigCls(
num_blocks=num_cpu_blocks,
kv_cache_tensors=cpu_tensors,
kv_cache_groups=gpu_config.kv_cache_groups,
)
@staticmethod
def _estimate_lazy_target_blocks(
kv_cache_config: "KVCacheConfig", max_num_batched_tokens: int
) -> int:
"""GPU blocks to keep available (free/offloaded) per step in lazy mode."""
WATERMARK_RATIO = 1.0 # Reserve larger space to avoid running out of GPU blocks
target = 0
for g in kv_cache_config.kv_cache_groups:
spec = g.kv_cache_spec
if isinstance(spec, MambaSpec):
target += 2
elif isinstance(spec, SlidingWindowSpec):
target += cdiv(spec.sliding_window, spec.block_size) + 1
else:
target += cdiv(max_num_batched_tokens, spec.block_size)
return int(target * (1 + WATERMARK_RATIO))
def bind_gpu_block_pool(self, gpu_block_pool: BlockPool) -> None:
"""Bind GPU block pool so that we can touch blocks during stores.
Called by Scheduler after kv_cache_manager is ready."""
self._gpu_block_pool = gpu_block_pool
def get_num_new_matched_tokens(
self, request: "Request", num_computed_tokens: int
) -> tuple[int | None, bool]:
"""Return (num_new_tokens, is_async) from consecutive CPU cache hits."""
skipped = num_computed_tokens // self.block_size
remaining_hashes = request.block_hashes[skipped:]
if not remaining_hashes:
return 0, False
# Must recompute at least the last token, matching the logic in
# kv_cache_manager.get_computed_blocks().
max_hit_len = request.num_tokens - 1 - num_computed_tokens
if max_hit_len <= 0:
return 0, False
_, hit_length = self.cpu_coordinator.find_longest_cache_hit(
remaining_hashes, max_hit_len
)
if hit_length > 0:
return hit_length, True
return 0, False
# TODO(yifan): this API now only matches the suffix part of the prefix cache. A more
# general API should scan blocks in both GPU and CPU block pool in a single pass.
def update_state_after_alloc(
self,
request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int,
) -> None:
req_id = request.request_id
block_ids_by_group = blocks.get_block_ids()
num_groups = len(block_ids_by_group)
# Store tracking (eager mode only). Register the request;
# block IDs are accumulated from scheduler_output in
# _prepare_eager_store_specs via yield_req_data.
if not self._lazy_mode and req_id not in self._reqs_to_store:
self._reqs_to_store[req_id] = StoreRequestState(
request=request,
block_ids=tuple([] for _ in range(num_groups)),
num_stored_blocks=[0] * num_groups,
)
if num_external_tokens == 0:
return
num_blocks_to_load = num_external_tokens // self.block_size
assert num_blocks_to_load > 0
skipped = sum(blk.block_hash is not None for blk in blocks.blocks[self.fa_gidx])
num_computed_tokens = skipped * self.block_size
hashes_to_load = request.block_hashes[skipped : skipped + num_blocks_to_load]
# Find CPU cached blocks across all groups.
max_hit_len = len(hashes_to_load) * self.block_size
cpu_hit_blocks, hit_length = self.cpu_coordinator.find_longest_cache_hit(
hashes_to_load, max_hit_len
)
assert hit_length == num_external_tokens, (
f"Expected {num_external_tokens} hit tokens, got {hit_length}"
)
# Build transfer pairs across all groups.
total_computed_tokens = num_computed_tokens + num_external_tokens
kv_cache_groups = self.cpu_kv_cache_config.kv_cache_groups
gpu_block_ids: list[int] = []
cpu_block_ids: list[int] = []
cpu_blocks_to_touch: list[KVCacheBlock] = []
for g in range(num_groups):
cpu_blocks_g = cpu_hit_blocks[g]
n_ext_g = len(cpu_blocks_g)
if n_ext_g == 0:
continue
# Number of blocks in the computed range for this group.
g_block_size = kv_cache_groups[g].kv_cache_spec.block_size
n_computed_g = cdiv(total_computed_tokens, g_block_size)
# Back-trace: ext blocks sit at the tail of the computed range.
gpu_ext_start = n_computed_g - n_ext_g
group_gpu_ids = block_ids_by_group[g]
for i, cpu_blk in enumerate(cpu_blocks_g):
# Skip null blocks (e.g. sliding window or mamba padding).
if cpu_blk.is_null:
continue
gpu_block_ids.append(group_gpu_ids[gpu_ext_start + i])
cpu_block_ids.append(cpu_blk.block_id)
cpu_blocks_to_touch.append(cpu_blk)
# Touch CPU blocks to prevent eviction during async load.
self.cpu_block_pool.touch(cpu_blocks_to_touch)
# Touch GPU blocks to prevent freeing during async load
assert self._gpu_block_pool is not None
self._gpu_block_pool.touch(
[self._gpu_block_pool.blocks[bid] for bid in gpu_block_ids]
)
assert self._reqs_to_load.get(req_id) is None
self._reqs_to_load[req_id] = LoadRequestState(
request=request, transfer_meta=TransferMeta(gpu_block_ids, cpu_block_ids)
)
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> SimpleCPUOffloadMetadata:
# --- Stores ---
store_event = -1
store_gpu, store_cpu, store_req_ids = self.prepare_store_specs(scheduler_output)
if store_gpu:
store_event = self._store_event_counter
self._store_event_counter += 1
self._store_event_to_blocks[store_event] = TransferMeta(
store_gpu, store_cpu
)
if store_req_ids: # For eager mode only, track req->blocks mapping
self._store_event_to_reqs[store_event] = store_req_ids
for req_id in store_req_ids:
store_state = self._reqs_to_store.get(req_id)
if store_state is not None:
store_state.store_events.add(store_event)
# --- Loads ---
load_event = -1
load_gpu: list[int] = []
load_cpu: list[int] = []
load_req_ids: list[str] = []
for req_id, load_state in self._reqs_to_load.items():
if load_state.load_event is not None:
continue
assert load_state.transfer_meta is not None
load_gpu.extend(load_state.transfer_meta.gpu_block_ids)
load_cpu.extend(load_state.transfer_meta.cpu_block_ids)
load_req_ids.append(req_id)
if load_req_ids:
load_event = self._load_event_counter
self._load_event_counter += 1
for req_id in load_req_ids:
self._reqs_to_load[req_id].load_event = load_event
self._load_event_to_reqs[load_event] = load_req_ids
result = SimpleCPUOffloadMetadata(
load_event=load_event,
load_gpu_blocks=load_gpu,
load_cpu_blocks=load_cpu,
load_event_to_reqs=self._load_event_to_reqs,
store_event=store_event,
store_gpu_blocks=store_gpu,
store_cpu_blocks=store_cpu,
need_flush=bool(scheduler_output.preempted_req_ids),
)
return result
def prepare_store_specs(
self, scheduler_output: SchedulerOutput
) -> tuple[list[int], list[int], list[str]]:
"""Prepare store specs for the store event."""
if self._lazy_mode:
return self._prepare_lazy_store_specs()
else:
return self._prepare_eager_store_specs(scheduler_output)
def _prepare_lazy_store_specs(
self,
) -> tuple[list[int], list[int], list[str]]:
"""Single-pass cursor walk: offload cached GPU blocks near eviction.
Walks the GPU free queue from the cursor, counting blocks that are
free-or-offloaded (safe for the allocator to evict). Stops when
target_free blocks are covered or CPU capacity is reached.
"""
gpu_pool = self._gpu_block_pool
if gpu_pool is None or self._target_free <= 0:
return [], [], []
free_queue = gpu_pool.free_block_queue
cpu_pool = self.cpu_block_pool
num_cpu_free = cpu_pool.get_num_free_blocks()
# Validate cursor: stale if block was removed from free queue.
if self._cursor is not None and self._cursor.ref_cnt > 0:
self._cursor = None
# Determine start node.
if self._cursor is None:
node = free_queue.fake_free_list_head.next_free_block
else:
node = self._cursor.next_free_block
tail = free_queue.fake_free_list_tail
gpu_ids: list[int] = []
block_hashes: list[bytes] = []
covered = 0
last_visited = self._cursor
while (
node is not None
and node is not tail
and covered < self._target_free
and len(gpu_ids) < num_cpu_free
):
last_visited = node
bhash = node.block_hash
if (
bhash is not None
and not node.is_null
and cpu_pool.cached_block_hash_to_block.get_one_block(bhash) is None
):
gpu_ids.append(node.block_id)
block_hashes.append(bhash)
covered += 1
node = node.next_free_block
self._cursor = last_visited
# Batch-allocate CPU blocks and stamp hashes.
if gpu_ids:
cpu_blocks = cpu_pool.get_new_blocks(len(gpu_ids))
cpu_ids = [blk.block_id for blk in cpu_blocks]
for cpu_blk, bhash in zip(cpu_blocks, block_hashes): # type: ignore[assignment]
cpu_blk._block_hash = bhash # type: ignore[assignment]
# Touch GPU blocks to prevent eviction during async copy.
gpu_pool.touch([gpu_pool.blocks[bid] for bid in gpu_ids])
else:
cpu_ids = []
return gpu_ids, cpu_ids, []
def _prepare_eager_store_specs(
self, scheduler_output: SchedulerOutput
) -> tuple[list[int], list[int], list[str]]:
"""Identify newly computed blocks to offload from scheduler requests.
Only considers blocks whose KV data has been **confirmed computed** by
the GPU. This means blocks from the current step are NOT stored until the
next step. If a request finishes in the same step as its last full block,
that block may be missed. (TODO: flush on finish.)
Returns:
(gpu_block_ids, cpu_block_ids, req_ids) for the store event.
"""
merged_gpu_block_ids: list[int] = []
merged_cpu_block_ids: list[int] = []
req_ids: list[str] = []
gpu_block_pool = self._gpu_block_pool
if gpu_block_pool is None:
return [], [], []
cpu_block_pool = self.cpu_block_pool
num_free = cpu_block_pool.get_num_free_blocks()
kv_cache_groups = self.cpu_kv_cache_config.kv_cache_groups
num_groups = len(kv_cache_groups)
gpu_blocks_this_step: set[int] = set()
for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output):
state = self._reqs_to_store.get(req_id)
if state is None or state.finished:
continue
# Accumulate new block IDs.
if preempted:
state.block_ids = tuple([] for _ in range(num_groups))
state.num_stored_blocks = [0] * num_groups
if new_block_id_groups:
for g in range(min(num_groups, len(new_block_id_groups))):
if new_block_id_groups[g] is not None:
state.block_ids[g].extend(new_block_id_groups[g])
num_new_tokens = scheduler_output.num_scheduled_tokens.get(req_id, 0)
if num_new_tokens == 0:
continue
block_ids_by_group = state.block_ids
if not block_ids_by_group:
continue
# --- Phase 1: Scan blocks, classify as cached vs to-store ---
gpu_block_ids: list[int] = []
block_hashes_to_store: list[bytes] = []
advanced_per_group: list[int] = [0] * num_groups
out_of_space = False
# Confirmed tokens: KV data written and visible to all streams.
req = state.request
confirmed_tokens = req.num_computed_tokens - req.num_output_placeholders
for g in range(num_groups):
# FIXME (yifan): handle CPU cache eviction, where
# num_stored_blocks can be stale and omit evicted blocks in
# the middle of the request.
already_stored_g = state.num_stored_blocks[g]
group_gpu_ids = block_ids_by_group[g]
# Cap to blocks with confirmed KV data.
g_block_size = kv_cache_groups[g].kv_cache_spec.block_size
ready_blocks_g = confirmed_tokens // g_block_size
scannable = group_gpu_ids[already_stored_g:ready_blocks_g]
for gpu_block_id in scannable:
gpu_block = gpu_block_pool.blocks[gpu_block_id]
if gpu_block.is_null:
advanced_per_group[g] += 1
continue
bhash_with_group = gpu_block.block_hash
if bhash_with_group is None:
break
# Check if this group's data is already scheduled for store
# in this step or already cached in CPU.
if (
gpu_block_id in gpu_blocks_this_step
or cpu_block_pool.cached_block_hash_to_block.get_one_block(
bhash_with_group
)
is not None
):
advanced_per_group[g] += 1
continue
if num_free <= 0:
out_of_space = True
break
num_free -= 1
gpu_block_ids.append(gpu_block_id)
block_hashes_to_store.append(bhash_with_group)
advanced_per_group[g] += 1
if out_of_space:
break
# --- Phase 2: Batch allocate CPU blocks and stamp hashes ---
n_to_alloc = len(gpu_block_ids)
if n_to_alloc > 0:
cpu_blocks_alloc = cpu_block_pool.get_new_blocks(n_to_alloc)
cpu_block_ids = [blk.block_id for blk in cpu_blocks_alloc]
for cpu_blk, bhash in zip(cpu_blocks_alloc, block_hashes_to_store):
cpu_blk._block_hash = bhash # type: ignore[assignment]
else:
cpu_block_ids = []
if cpu_block_ids:
req_ids.append(req_id)
merged_gpu_block_ids.extend(gpu_block_ids)
merged_cpu_block_ids.extend(cpu_block_ids)
gpu_blocks_this_step.update(gpu_block_ids)
# Touch GPU blocks to prevent freeing during async copy
gpu_block_pool.touch(
[gpu_block_pool.blocks[bid] for bid in gpu_block_ids]
)
logger.debug(
"Request %s: Scheduling store of %d blocks to CPU (%d groups)",
req_id,
len(cpu_block_ids),
num_groups,
)
# Advance per-group cursors (includes cached hits + newly stored)
for g in range(num_groups):
state.num_stored_blocks[g] += advanced_per_group[g]
return merged_gpu_block_ids, merged_cpu_block_ids, req_ids
def update_connector_output(self, connector_output: KVConnectorOutput) -> None:
"""Handle async transfer completions from worker.
Load completions arrive via finished_recving (real req_ids).
Store completions arrive via kv_connector_worker_meta as
per-event worker counts. We accumulate across steps and process
a store event only when all workers have reported completion.
"""
# --- Load completions ---
for req_id in list(connector_output.finished_recving or []):
self._cleanup_load_request(req_id)
# --- Store completions ---
meta = connector_output.kv_connector_worker_meta
if not isinstance(meta, SimpleCPUOffloadWorkerMetadata):
return
for event_idx, count in meta.completed_store_events.items():
total = self._store_event_pending_counts.get(event_idx, 0) + count
if total >= self._expected_worker_count:
self._store_event_pending_counts.pop(event_idx, None)
self._process_store_event(event_idx)
else:
self._store_event_pending_counts[event_idx] = total
def _process_store_event(self, event_idx: int) -> None:
"""Process a fully-completed store event."""
transfer = self._store_event_to_blocks.pop(event_idx)
self._process_store_completion(transfer.gpu_block_ids, transfer.cpu_block_ids)
logger.debug(
"Store event %d completed: cached %d blocks to CPU",
event_idx,
len(transfer.cpu_block_ids),
)
# Eager only: update per-req state
if not self._lazy_mode:
for req_id in self._store_event_to_reqs.pop(event_idx, []):
state = self._reqs_to_store.get(req_id)
if state is None:
continue
state.store_events.discard(event_idx)
if state.finished and not state.store_events:
self._cleanup_store_request(req_id)
def _process_store_completion(
self, gpu_block_ids: list[int], cpu_block_ids: list[int]
) -> None:
"""Cache CPU blocks per-group and release GPU refs.
Block hashes were stamped on CPU blocks at allocation time (in
``_prepare_*_store_specs``). Here we just register them in the
cache map so they become discoverable by the load path.
"""
assert len(cpu_block_ids) == len(gpu_block_ids)
cpu_blocks = [self.cpu_block_pool.blocks[bid] for bid in cpu_block_ids]
for cpu_block in cpu_blocks:
bhash = cpu_block.block_hash
assert bhash is not None
self.cpu_block_pool.cached_block_hash_to_block.insert(bhash, cpu_block)
# Free CPU and GPU blocks' ref counts to turn them into prefix cache
self.cpu_block_pool.free_blocks(cpu_blocks)
assert self._gpu_block_pool is not None
self._gpu_block_pool.free_blocks(
self._gpu_block_pool.blocks[bid] for bid in gpu_block_ids
)
def has_pending_stores(self) -> bool:
"""Return True if there are in-flight store transfers."""
return bool(self._store_event_to_blocks)
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
"""Always returns (False, None). GPU blocks are protected by ref_cnt,
so the scheduler can free blocks immediately."""
req_id = request.request_id
# Handle load: defer cleanup if load is in-flight
load_state = self._reqs_to_load.get(req_id)
if load_state is not None:
if load_state.load_event is not None:
load_state.finished = True # Defer: load in-flight
else:
self._cleanup_load_request(req_id)
# Handle store (eager mode only): defer cleanup if stores in-flight
if not self._lazy_mode:
store_state = self._reqs_to_store.get(req_id)
if store_state is not None:
if store_state.store_events:
store_state.finished = True # Defer: stores in-flight
else:
self._cleanup_store_request(req_id)
return False, None
def request_finished_all_groups(
self,
request: "Request",
block_ids: tuple[list[int], ...],
) -> tuple[bool, dict[str, Any] | None]:
return self.request_finished(request, block_ids=[])
def _cleanup_load_request(self, req_id: str) -> None:
"""Release all load resources for a request.
Shared between request_finished() and update_connector_output() paths.
Removes the request from _reqs_to_load, cleans up event mappings,
and frees CPU/GPU touch refs.
"""
state = self._reqs_to_load.pop(req_id, None)
if state is None:
return
# Remove from load event mapping (only this req, not whole event)
if state.load_event is not None:
reqs = self._load_event_to_reqs.get(state.load_event)
if reqs is not None:
with contextlib.suppress(ValueError):
reqs.remove(req_id)
if not reqs:
self._load_event_to_reqs.pop(state.load_event, None)
if state.transfer_meta is not None:
# Free CPU touch refs
self.cpu_block_pool.free_blocks(
self.cpu_block_pool.blocks[bid]
for bid in state.transfer_meta.cpu_block_ids
)
# Free GPU touch refs
assert self._gpu_block_pool is not None
self._gpu_block_pool.free_blocks(
self._gpu_block_pool.blocks[bid]
for bid in state.transfer_meta.gpu_block_ids
)
def _cleanup_store_request(self, req_id: str) -> None:
"""Release store metadata for a request.
Metadata-only cleanup but no block freeing. Job completion handles
block caching and GPU ref freeing via _process_store_completion().
"""
state = self._reqs_to_store.pop(req_id, None)
if state is None:
return
for event_idx in list(state.store_events):
if (reqs := self._store_event_to_reqs.get(event_idx)) is not None:
with contextlib.suppress(ValueError):
reqs.remove(req_id)
if not reqs:
self._store_event_to_reqs.pop(event_idx, None)
state.store_events.clear()
def take_events(self) -> Iterable[KVCacheEvent]:
return self.cpu_block_pool.take_events()