Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@

from __future__ import annotations

import asyncio
import uuid
from contextlib import ExitStack

import pytest
import torch

from tests.helpers.mark import hardware_test
from vllm_omni.entrypoints.async_omni import AsyncOmni
Expand Down Expand Up @@ -150,3 +152,101 @@
assert last_output is not None
assert isinstance(last_output, OmniRequestOutput)
assert last_output.images, "Expected at least one generated image"


@pytest.mark.core_model
@pytest.mark.diffusion
@hardware_test(res={"cuda": "L4"}, num_cards=1)
@pytest.mark.asyncio
async def test_sleep_memory_reclaimed_custom_pipeline():
    """sleep(level=1) must physically reclaim CuMemAllocator-tracked memory for
    custom_pipeline.

    Regression test for: custom pipelines constructed under ``with target_device:``
    (CUDA default-device context) caused safetensors >=0.20.0 to use a
    direct-to-GPU fast path (cudaMalloc via the driver API) that bypasses
    CuMemAllocator, leaving weights invisible to sleep() and pinned in GPU
    memory after the call.

    The fix moves custom_pipeline init outside the CUDA context so all weights
    go through the caching allocator and are therefore fully reclaimed by
    sleep(level=1).  A non-zero ``CuMemAllocator.get_current_usage()`` after
    sleep is the direct signal that the bypass is still occurring.

    The test also checks that the engine reports the sleeping state and that
    it reports being awake again after ``wake_up()``.
    """
    with ExitStack() as after:
        engine = AsyncOmni(
            model=MODEL,
            custom_pipeline_args={"pipeline_class": CUSTOM_PIPELINE_CLASS},
            worker_extension_cls=WORKER_EXTENSION_CLASS,
            enforce_eager=True,
            enable_sleep_mode=True,
        )
        # Guarantee the engine is shut down even when an assertion fails.
        after.callback(engine.shutdown)

        assert not await engine.is_sleeping(), "Engine should be awake after creation"

        # Measure global VRAM before sleep (driver view; includes inline worker
        # thread since inline mode runs in the same process).
        # NOTE: ``torch.cuda.synchronize`` is banned by the project's Ruff
        # TID251 rule; the device-agnostic accelerator API is used instead.
        torch.accelerator.synchronize()
        free_before, total = torch.cuda.mem_get_info()
        used_before_gib = (total - free_before) / 1024**3

        # Measure CuMemAllocator-tracked usage before sleep.  In inline mode
        # the worker runs in a thread pool inside this process, so the allocator
        # singleton is shared and can be read directly.
        allocator = None
        tracked_before = 0
        try:
            from vllm.device_allocator.cumem import CuMemAllocator

            allocator = CuMemAllocator.get_instance()
            tracked_before = allocator.get_current_usage()
        except Exception:
            # Deliberate best-effort probe: when no allocator instance could be
            # obtained, ``allocator`` stays ``None`` and the primary assertion
            # below is skipped; the secondary checks still run.
            pass

        # Put the engine to sleep; all weights should be offloaded via the pool.
        acks = await engine.sleep(level=1)
        await asyncio.sleep(0.5)  # allow the CUDA driver to settle
        torch.accelerator.synchronize()

        # Measure after sleep.
        free_after, _ = torch.cuda.mem_get_info()
        used_after_gib = (total - free_after) / 1024**3
        drop_gib = used_before_gib - used_after_gib

        # --- Primary assertion: allocator reports zero tracked memory. ---
        # If this fails it means weights were allocated outside the CuMem pool
        # (safetensors direct-to-GPU bypass) — the exact regression this test
        # is designed to catch.
        if allocator is not None:
            tracked_after = allocator.get_current_usage()
            assert tracked_after == 0, (
                f"CuMemAllocator still tracks {tracked_after / 1024**3:.3f} GiB "
                f"after sleep(level=1) on custom_pipeline path "
                f"(was {tracked_before / 1024**3:.3f} GiB before sleep). "
                "Weights were allocated outside the CuMem pool via the "
                "safetensors direct-to-GPU fast path — loader-context fix "
                "may not be applied."
            )

        # --- Secondary assertion: physical VRAM or ACK freed_bytes confirms
        # reclamation at the driver level.
        # Each ACK is read either as an object exposing ``freed_bytes`` or as
        # a mapping with a ``"freed_bytes"`` key; ``None`` ACKs are ignored.
        total_freed_bytes = sum(
            (ack.freed_bytes if hasattr(ack, "freed_bytes") else ack.get("freed_bytes", 0))
            for ack in acks
            if ack is not None
        )
        freed_gib = total_freed_bytes / 1024**3
        assert freed_gib > 0 or drop_gib > 0, (
            f"Expected GPU memory to be reclaimed after sleep(level=1) on "
            f"custom_pipeline + enable_sleep_mode=True. "
            f"CuMemAllocator tracked before={tracked_before / 1024**3:.3f} GiB, "
            f"ACK freed={freed_gib:.3f} GiB, global VRAM drop={drop_gib:.3f} GiB."
        )

        # Engine must report it is sleeping.
        assert await engine.is_sleeping()

        # Wake up and confirm the engine is functional again.
        await engine.wake_up()
        assert not await engine.is_sleeping()
22 changes: 22 additions & 0 deletions vllm_omni/diffusion/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,6 +996,28 @@ class DiffusionOutput:
# memory usage info
peak_memory_mb: float = 0.0

# When True, move all tensor fields (including tensors inside
# ``custom_output``) to CPU at construction time. Useful when the output
# is shipped across process boundaries (e.g. step-execution mode) and the
# receiving side must not initialise a stray CUDA context.
to_cpu: bool = False

def __post_init__(self) -> None:
    """Move every tensor carried by this output to CPU when ``to_cpu`` is set.

    Does nothing when ``self.to_cpu`` is falsy.  Otherwise ``output``,
    ``trajectory_timesteps``, ``trajectory_latents``, ``trajectory_log_probs``
    and the values of ``custom_output`` are rewritten so that each
    ``torch.Tensor`` they contain is replaced by ``tensor.detach().cpu()``.
    Tensors nested inside lists, tuples (including namedtuples) and dicts are
    converted too; any other value is kept unchanged.
    """
    if not self.to_cpu:
        return

    def _maybe_to_cpu(value: Any) -> Any:
        # Tensors are detached from autograd and copied to host memory.
        if isinstance(value, torch.Tensor):
            return value.detach().cpu()
        # Recurse into the common containers so a nested tensor cannot stay
        # behind on the device.
        if isinstance(value, dict):
            return {key: _maybe_to_cpu(item) for key, item in value.items()}
        if isinstance(value, list):
            return [_maybe_to_cpu(item) for item in value]
        if isinstance(value, tuple):
            converted = [_maybe_to_cpu(item) for item in value]
            # namedtuple constructors take positional fields, not an iterable.
            if hasattr(value, "_fields"):
                return type(value)(*converted)
            return tuple(converted)
        return value

    self.output = _maybe_to_cpu(self.output)
    self.trajectory_timesteps = _maybe_to_cpu(self.trajectory_timesteps)
    self.trajectory_latents = _maybe_to_cpu(self.trajectory_latents)
    self.trajectory_log_probs = _maybe_to_cpu(self.trajectory_log_probs)
    # An empty or ``None`` ``custom_output`` is left exactly as it was.
    if self.custom_output:
        self.custom_output = {k: _maybe_to_cpu(v) for k, v in self.custom_output.items()}


class DiffusionRequestAbortedError(RuntimeError):
"""Raised when a diffusion request ends via user-visible abort."""
Expand Down
41 changes: 28 additions & 13 deletions vllm_omni/diffusion/model_loader/diffusers_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,19 +313,34 @@ def load_model(
od_config, target_device=device, load_format=load_format, custom_pipeline_name=custom_pipeline_name
)
else:
with target_device:
if load_format == "default":
model = initialize_model(od_config)
elif load_format == "diffusers":
model = DiffusersAdapterPipeline(od_config=od_config, device=target_device)
elif load_format == "custom_pipeline":
from vllm_omni.diffusion.config import set_current_diffusion_config

model_cls = resolve_obj_by_qualname(custom_pipeline_name)
with set_current_diffusion_config(od_config):
model = model_cls(od_config=od_config)
else:
raise ValueError(f"Unknown load_format: {load_format}")
if load_format == "custom_pipeline":
# NOTE: Custom pipelines call HuggingFace `from_pretrained(...).to(device)`
# internally. If we construct them under `with target_device:` (CUDA),
# safetensors takes a direct-to-GPU fast path that calls `cudaMalloc`
# via the driver API and BYPASSES PyTorch's caching allocator.
# That makes those bytes invisible to CuMemAllocator, so `sleep()`
# cannot offload/unmap them and GPU memory stays pinned.
#
# Fix: build the custom pipeline on CPU first (no default device
# context), then explicitly move it to the target device. The
# subsequent `.to(target_device)` issues `torch.empty(..., device=cuda)`
# + `copy_`, which goes through the caching allocator and is fully
# tracked by CuMemAllocator.
from vllm_omni.diffusion.config import set_current_diffusion_config

model_cls = resolve_obj_by_qualname(custom_pipeline_name)
with set_current_diffusion_config(od_config):
model = model_cls(od_config=od_config)
if target_device.type != "cpu":
model.to(target_device)
else:
with target_device:
if load_format == "default":
model = initialize_model(od_config)
elif load_format == "diffusers":
model = DiffusersAdapterPipeline(od_config=od_config, device=target_device)
else:
raise ValueError(f"Unknown load_format: {load_format}")
logger.debug("Loading weights on %s ...", load_device)
if load_format == "diffusers":
# DiffusersAdapterPipeline.load_weights() calls
Expand Down
Loading