fix(vllm): force spawn for single-GPU EngineCore and fix health check scheduling#5030

Open
m199369309 wants to merge 4 commits into xorbitsai:main from m199369309:fix/vllm-single-gpu-deadlock-and-health-check

@m199369309

Collaborator

Summary

  • Force spawn for single-GPU vLLM v1 EngineCore creation to avoid fork deadlock. When tensor_parallel_size=1, vLLM's MultiprocExecutor creates EngineCore via mp.get_context("fork").Process(), inheriting the parent's multi-thread lock state and causing permanent deadlock in the request processing loop. Setting VLLM_WORKER_MULTIPROC_METHOD=spawn creates a clean process. vLLM natively supports this env var (vllm/envs.py) and ZMQ IPC uses path strings, so spawn is safe.
  • Move health check creation from load() to wait_for_load(). The old code created it while self._engine was still None on the threaded loading paths (both multi-GPU and single-GPU), so the health check was silently skipped entirely.
  • Fix AsyncEngineDeadError import compatibility: the class was removed from vllm.engine.async_llm_engine in vLLM 0.19.0+. utils.py now uses a module-level try/except, and _check_healthy() falls back to catching RuntimeError.
  • Use call_soon_threadsafe + create_task instead of run_coroutine_threadsafe. The latter returns a concurrent.futures.Future whose exception surfaces only if the caller retrieves the result, so failures were silently swallowed; create_task produces an asyncio.Task whose exceptions are visible.
  • Wait for loading thread in stop() before shutting down the engine, preventing orphan EngineCore subprocesses.
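
The last bullet (join the loading thread before shutting the engine down) can be illustrated with a minimal sketch. `LoaderSketch` and its attributes are hypothetical stand-ins chosen for illustration, not xinference's actual classes; the point is only the ordering inside `stop()`.

```python
import threading
import time


class LoaderSketch:
    """Hypothetical stand-in for a model wrapper; not the real xinference class."""

    def __init__(self):
        self.events = []              # records the order of operations
        self._engine = None
        self._loading_thread = None

    def load(self):
        # Engine creation happens on a background thread.
        self._loading_thread = threading.Thread(target=self._create_engine)
        self._loading_thread.start()

    def _create_engine(self):
        time.sleep(0.1)               # simulate slow engine start-up
        self._engine = object()
        self.events.append("engine_created")

    def stop(self, timeout=5.0):
        # Join first, so an engine that is still being created is not
        # missed and left behind without an owner.
        thread = self._loading_thread
        if thread is not None and thread.is_alive():
            thread.join(timeout)
        if self._engine is not None:
            self.events.append("engine_shutdown")
            self._engine = None


model = LoaderSketch()
model.load()
model.stop()
print(model.events)  # ['engine_created', 'engine_shutdown']
```

Without the `join`, `stop()` could run while `_engine` is still `None`, skip the shutdown, and leave whatever the loading thread creates afterwards unmanaged.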

Test plan

  • Single-GPU Qwen3-8B (BF16): 3 consecutive chat requests succeed (previously hung indefinitely)
  • Single-GPU Llama-3.1-8B (BF16): chat requests succeed
  • "Begin to check health of vLLM" log appears for the first time (previously 0 occurrences across all deployments)
  • No Task exception was never retrieved errors
  • Multi-GPU (TP=2) deployment unaffected (uses XinferenceDistributedExecutorV1, no fork)
  • pytest xinference/model/llm/vllm/tests/ passes

Root cause analysis

Single-GPU vLLM v1 deployments (Qwen3-8B, Llama-3.1-8B, Qwen2.5-Instruct) experienced permanent request hangs:

  • Model loaded successfully, CUDA graphs captured
  • Any inference request hung indefinitely with no error logs
  • EngineCore subprocess stayed alive (stats telemetry continued)
  • serve request count kept incrementing, no timeout

The fork deadlock occurs because the parent process (ModelActor subpool) has multiple active threads (MainThread event loop, _loading_thread, asyncio workers). Fork copies all memory including lock states, but the child only has one thread — it cannot release locks held by the other threads.

Multi-GPU deployments are unaffected because they use XinferenceDistributedExecutorV1 (xoscar actors, no fork).
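
As a rough sketch of the workaround described above: the environment variable has to be in place before the engine is created. The helper name `choose_worker_start_method` and its `env` parameter are hypothetical and exist only for illustration; only the variable name `VLLM_WORKER_MULTIPROC_METHOD` and the value `spawn` come from this PR.

```python
import os


def choose_worker_start_method(tensor_parallel_size, env=None):
    """Hypothetical helper: force spawn for single-GPU, leave other setups alone."""
    env = os.environ if env is None else env
    if tensor_parallel_size == 1:
        # Must happen before the engine (and its EngineCore process) is created.
        env["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
    return env.get("VLLM_WORKER_MULTIPROC_METHOD")


# Plain dicts are used here so the example does not touch the real environment.
print(choose_worker_start_method(1, {}))  # spawn
print(choose_worker_start_method(2, {}))  # None
```

A spawned child starts from a fresh interpreter rather than a copy of the parent's memory, so it does not inherit locks that other parent threads were holding at fork time.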

🤖 Generated with Claude Code

… scheduling

Two fixes for vLLM v1 single-GPU deployments:

1. Force spawn for single-GPU EngineCore creation to avoid fork deadlock.
   vLLM v1 creates EngineCore via mp.get_context("fork").Process(), which
   inherits the parent's multi-thread lock state (MainThread event loop,
   asyncio worker threads). The child process only has one active thread
   and cannot release locks held by other threads, causing the request
   processing loop to deadlock permanently. Setting
   VLLM_WORKER_MULTIPROC_METHOD=spawn creates a clean process without
   inherited locks. vLLM natively supports this env var and ZMQ IPC uses
   path strings (not fd inheritance), so spawn is safe.

2. Move health check creation from load() to wait_for_load() and use
   call_soon_threadsafe + create_task instead of run_coroutine_threadsafe.
   The old code created the health check in load() when self._engine was
   still None for threaded paths, silently skipping it. The new code also
   fixes AsyncEngineDeadError import compatibility (removed from
   vllm.engine.async_llm_engine in vLLM 0.19.0+) by using a module-level
   try/except in utils.py and catching RuntimeError as fallback.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

@XprobeBot added the bug (Something isn't working) and gpu labels Jun 15, 2026
@XprobeBot added this to the v2.x milestone Jun 15, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request refactors the vLLM engine loading process to prevent fork deadlocks by spawning a dedicated loading thread, and defers the health check task creation until the engine is fully loaded. Feedback on these changes highlights three key issues: a critical bug in vllm_check on newer vLLM versions where the health check wrapper is bypassed and a TypeError is raised; an issue where the health check task is never assigned to self._check_health_task due to the asynchronous nature of call_soon_threadsafe, preventing cancellation; and the use of a bare except: clause that should be replaced with except Exception:.

Comment on lines +21 to 40
+try:
+    from vllm.engine.async_llm_engine import AsyncEngineDeadError
+except ImportError:
+    # vLLM 0.19.0+ removed AsyncEngineDeadError from this module.
+    # It is a subclass of RuntimeError, so except RuntimeError suffices.
+    AsyncEngineDeadError = None  # type: ignore[assignment,misc]
+
+
 def vllm_check(fn):
-    try:
-        from vllm.engine.async_llm_engine import AsyncEngineDeadError
-    except Exception:
+    if AsyncEngineDeadError is None:
         return fn

     @functools.wraps(fn)
     async def _async_wrapper(self, *args, **kwargs):
         try:
             return await fn(self, *args, **kwargs)
-        except AsyncEngineDeadError:
+        except (AsyncEngineDeadError, RuntimeError):
             logger.info("Detecting vLLM is not health, prepare to quit the process")
             try:
                 self.stop()

critical

There is a critical bug in vllm_check when running on vLLM 0.19.0+:

  1. AsyncEngineDeadError is imported as None if the import fails.
  2. If AsyncEngineDeadError is None, vllm_check returns fn directly without wrapping it. This means RuntimeError is never caught for any decorated functions (like async_generate and async_chat), completely disabling the health check and auto-recovery mechanism on newer vLLM versions.
  3. If the function were wrapped, catching (AsyncEngineDeadError, RuntimeError) where AsyncEngineDeadError is None would raise a TypeError: catching classes that do not inherit from BaseException is not allowed at runtime.

To fix both issues, AsyncEngineDeadError should fall back to RuntimeError when the import fails, and the function should always be wrapped.

try:
    from vllm.engine.async_llm_engine import AsyncEngineDeadError
except ImportError:
    # vLLM 0.19.0+ removed AsyncEngineDeadError from this module.
    # It is a subclass of RuntimeError, so except RuntimeError suffices.
    AsyncEngineDeadError = RuntimeError  # type: ignore[assignment,misc]


def vllm_check(fn):
    @functools.wraps(fn)
    async def _async_wrapper(self, *args, **kwargs):
        try:
            return await fn(self, *args, **kwargs)
        except (AsyncEngineDeadError, RuntimeError):
            logger.info("Detecting vLLM is not health, prepare to quit the process")
            try:
                self.stop()
            except Exception:
                # ignore error when stop
                pass
            # Just kill the process and let xinference auto-recover the model
            os._exit(1)

    return _async_wrapper
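
The TypeError described in point 3 of the comment above can be reproduced in isolation. This is a standalone sketch; `call_with_handler` is a hypothetical name, and the `None` entry in the tuple stands in for the failed import.

```python
def call_with_handler(handler_types):
    """Raise RuntimeError and try to catch it with the given except target."""
    try:
        try:
            raise RuntimeError("engine dead")
        except handler_types:
            return "handled"
    except TypeError as exc:
        # Python rejects the except target itself when it contains a
        # non-exception object such as None.
        return f"TypeError: {exc}"


# None in the tuple makes the match itself fail with a TypeError.
print(call_with_handler((None, RuntimeError)))

# With the RuntimeError fallback the tuple is valid and the error is handled.
print(call_with_handler((RuntimeError, RuntimeError)))  # handled
```

The check only fires when an exception is actually being matched, which is why a decorator written this way would import cleanly and fail only at the moment the engine dies.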

Comment thread xinference/model/llm/vllm/core.py Outdated
Comment on lines +691 to +699
self._check_health_task = None
if self._engine and hasattr(self._engine, "check_health") and self._loop:
    logger.info(
        "Creating vLLM health check task for model %s",
        self.model_uid,
    )
    self._loop.call_soon_threadsafe(
        self._loop.create_task, self._check_healthy()
    )

high

The health check task self._check_health_task is initialized to None and never updated with the created task because self._loop.call_soon_threadsafe runs asynchronously and its return value is discarded. As a result, self._check_health_task remains None forever, and calling self._check_health_task.cancel() in stop() will have no effect, leaving the health check task running as an orphan task.

To fix this, we should define a thread-safe helper function that assigns the created task to self._check_health_task when executed on the event loop, and also check if the engine is still active to avoid race conditions.

Suggested change
-self._check_health_task = None
-if self._engine and hasattr(self._engine, "check_health") and self._loop:
-    logger.info(
-        "Creating vLLM health check task for model %s",
-        self.model_uid,
-    )
-    self._loop.call_soon_threadsafe(
-        self._loop.create_task, self._check_healthy()
-    )
+self._check_health_task = None
+if self._engine and hasattr(self._engine, "check_health") and self._loop:
+    logger.info(
+        "Creating vLLM health check task for model %s",
+        self.model_uid,
+    )
+
+    def _start_health_check():
+        if self._engine is not None:
+            self._check_health_task = self._loop.create_task(self._check_healthy())
+
+    self._loop.call_soon_threadsafe(_start_health_check)
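
The helper-function pattern in the suggestion above can be tried outside xinference. This is a minimal sketch assuming an event loop running in a separate thread; `HealthCheckSketch`, `schedule_from_other_thread`, and the `started` event are hypothetical names added for illustration.

```python
import asyncio
import threading


class HealthCheckSketch:
    """Hypothetical holder; attribute names mirror the snippet above."""

    def __init__(self, loop):
        self._loop = loop
        self._check_health_task = None
        self.started = threading.Event()

    async def _check_healthy(self):
        while True:                       # stand-in for the periodic check
            await asyncio.sleep(0.05)

    def schedule_from_other_thread(self):
        def _start():
            # Runs on the loop thread, so the Task reference can be
            # stored for later cancellation.
            self._check_health_task = self._loop.create_task(self._check_healthy())
            self.started.set()

        self._loop.call_soon_threadsafe(_start)

    def stop(self):
        task = self._check_health_task
        if task is not None:
            self._loop.call_soon_threadsafe(task.cancel)


async def _wait_for(task):
    # Collect the cancellation as a result instead of propagating it.
    await asyncio.gather(task, return_exceptions=True)


loop = asyncio.new_event_loop()
thread = threading.Thread(target=loop.run_forever, daemon=True)
thread.start()

holder = HealthCheckSketch(loop)
holder.schedule_from_other_thread()
holder.started.wait(timeout=5)
task = holder._check_health_task

holder.stop()
asyncio.run_coroutine_threadsafe(_wait_for(task), loop).result(timeout=5)
print(task.cancelled())  # True

loop.call_soon_threadsafe(loop.stop)
thread.join(timeout=5)
loop.close()
```

Passing `self._loop.create_task` directly to `call_soon_threadsafe` returns only a handle for the scheduled callback, never the Task, which is why the reference has to be captured inside the callback.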

Comment thread xinference/model/llm/vllm/core.py Outdated
Comment on lines +659 to +661
except:  # noqa: E722
    logger.exception("Creating vllm engine failed")
    self._loading_error = sys.exc_info()

medium

Using a bare except: clause is discouraged in PEP 8 as it catches system-exiting exceptions like SystemExit and KeyboardInterrupt, which can make it harder to interrupt the program and can disguise other issues. It is highly recommended to use except Exception: instead.

Suggested change
-except:  # noqa: E722
-    logger.exception("Creating vllm engine failed")
-    self._loading_error = sys.exc_info()
+except Exception:  # noqa: E722
+    logger.exception("Creating vllm engine failed")
+    self._loading_error = sys.exc_info()
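
The difference between the two clauses comes down to the exception hierarchy. In this illustrative sketch, `classify` is a hypothetical helper that reports which handler sees a given exception.

```python
def classify(exc):
    """Report whether `except Exception` would see `exc`."""
    try:
        raise exc
    except Exception:
        return "caught by except Exception"
    except BaseException:
        # KeyboardInterrupt and SystemExit derive from BaseException only,
        # so a bare except (or except BaseException) is what catches them.
        return "only a bare except would catch this"


for exc in (ValueError("bad"), KeyboardInterrupt(), SystemExit(1)):
    print(type(exc).__name__, "->", classify(exc))
```

With `except Exception:`, an interrupt or exit request raised during engine creation propagates instead of being recorded as a loading error.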

m199369309 and others added 3 commits June 15, 2026 16:27
…lback

- Fallback AsyncEngineDeadError to RuntimeError instead of None, so
  vllm_check decorator always wraps functions even on vLLM 0.19.0+
- Use helper function in call_soon_threadsafe to capture health check
  task reference, preventing orphan task that can't be cancelled
- Replace bare except: with except Exception: per PEP 8

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

Flake8 F401: AsyncEngineDeadError is no longer directly used in
core.py since _check_healthy() now catches RuntimeError instead.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

The CI environment may have a partial or broken vllm installation where
importing vllm.engine.async_llm_engine raises non-ImportError exceptions
(e.g., AttributeError, RuntimeError during module init). Using except
Exception (matching the original code before this PR) ensures the module
loads safely even when vllm is broken or partially installed.

This fixes cluster startup failures in CI where test_cmdline and other
integration tests using the setup fixture were timing out.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>