[Issue]: NCCL 2.27.5 fails with "operation not supported" (init.cc:426) on VMware vGPU when UVM is disabled, ignoring CUMEM disable flags #2101

@elliott-davis

How is this issue impacting you?

Application crash

Description

When initializing a multi-GPU NCCL process group (e.g., via PyTorch `dist.init_process_group` or vLLM with tensor parallelism > 1) on a VMware ESXi VM using NVIDIA vGPU, NCCL 2.27.5 crashes during initialization with `Cuda failure 'operation not supported'` at `init.cc:426`.
This occurs because the ESXi VM does not have Unified Virtual Memory (UVM) enabled (the `pciPassthruX.cfg.enable_uvm=1` VM setting is omitted). Enabling UVM resolves the NCCL crash, but it also completely breaks VMware vMotion, which is a strict enterprise requirement for our infrastructure.
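For context, the UVM toggle mentioned above is a per-VM advanced setting in the `.vmx` file (sketch; the device index `0` is an assumption and must match whichever `pciPassthru` device backs the vGPU):

```
pciPassthru0.cfg.enable_uvm = "1"
```

We cannot deploy this setting, because a VM carrying it is no longer eligible for vMotion.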
Furthermore, attempting to bypass the new cuMem allocation paths by setting `NCCL_CUMEM_ENABLE=0` and `NCCL_CUMEM_HOST_ENABLE=0` does not prevent the crash. The failure happens early in `init.cc`, suggesting these environment variables are either ignored or checked too late in the initialization sequence.
NCCL 2.21.5 works correctly in this exact same environment without UVM enabled, indicating that the underlying NVLink/P2P hardware topology is fully functional.
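For completeness, the failure also reproduces outside vLLM with a minimal two-rank script (a sketch of our repro; the file name and the `torchrun` launch are our setup, not part of vLLM):

```python
# repro_nccl_init.py - minimal 2-GPU NCCL init; fails at init.cc:426 on 2.27.5 without UVM.
# Launch with: NCCL_DEBUG=INFO torchrun --nproc_per_node=2 repro_nccl_init.py
import os

import torch
import torch.distributed as dist

# Set before any NCCL communicator is created; has no effect on the crash in 2.27.5.
os.environ.setdefault("NCCL_CUMEM_ENABLE", "0")
os.environ.setdefault("NCCL_CUMEM_HOST_ENABLE", "0")


def main() -> None:
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl")
    # The NCCL communicator is created lazily; the first collective triggers
    # ncclCommInitRank and, in our environment, the 'operation not supported' failure.
    x = torch.ones(1, device=f"cuda:{local_rank}")
    dist.all_reduce(x)
    print(f"rank {dist.get_rank()}: all_reduce OK, sum={x.item()}")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

Under 2.21.5 the same script completes on both ranks; under 2.27.5 it aborts during communicator creation regardless of the `NCCL_CUMEM_*` flags.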

Logs

/usr/local/lib/python3.12/dist-packages/transformers/utils/hub.py:110: FutureWarning: Using TRANSFORMERS_CACHE is deprecated and will be removed in v5 of Transformers. Use HF_HOME instead.
warnings.warn(
/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/chat_completion/protocol.py:346: SyntaxWarning: invalid escape sequence '\e'
"(e.g. 'abcdabcdabcd...' or '\emoji \emoji \emoji ...'). This feature "
/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/completion/protocol.py:176: SyntaxWarning: invalid escape sequence '\e'
"(e.g. 'abcdabcdabcd...' or '\emoji \emoji \emoji ...'). This feature "
(APIServer pid=355) INFO 04-13 19:51:46 [utils.py:302]
(APIServer pid=355) INFO 04-13 19:51:46 [utils.py:302] █ █ █▄ ▄█
(APIServer pid=355) INFO 04-13 19:51:46 [utils.py:302] ▄▄ ▄█ █ █ █ ▀▄▀ █ version 0.17.0
(APIServer pid=355) INFO 04-13 19:51:46 [utils.py:302] █▄█▀ █ █ █ █ model /data/model
(APIServer pid=355) INFO 04-13 19:51:46 [utils.py:302] ▀▀ ▀▀▀▀▀ ▀▀▀▀▀ ▀ ▀
(APIServer pid=355) INFO 04-13 19:51:46 [utils.py:302]
(APIServer pid=355) INFO 04-13 19:51:46 [utils.py:238] non-default args: {'enable_auto_tool_choice': True, 'tool_call_parser': 'openai', 'port': 3080, 'model': '/data/model', 'served_model_name': ['gpt-oss-120b-elliott'], 'tensor_parallel_size': 2}
(APIServer pid=355) INFO 04-13 19:51:46 [model.py:531] Resolved architecture: GptOssForCausalLM
(APIServer pid=355) ERROR 04-13 19:51:46 [repo_utils.py:47] Error retrieving safetensors: Repo id must be in the form 'repo_name' or 'namespace/repo_name': '/data/model'. Use repo_type argument if needed., retrying 1 of 2
(APIServer pid=355) ERROR 04-13 19:51:48 [repo_utils.py:45] Error retrieving safetensors: Repo id must be in the form 'repo_name' or 'namespace/repo_name': '/data/model'. Use repo_type argument if needed.
(APIServer pid=355) INFO 04-13 19:51:48 [model.py:1889] Downcasting torch.float32 to torch.bfloat16.
(APIServer pid=355) INFO 04-13 19:51:48 [model.py:1554] Using max model len 131072
(APIServer pid=355) INFO 04-13 19:51:51 [scheduler.py:231] Chunked prefill is enabled with max_num_batched_tokens=8192.
(APIServer pid=355) INFO 04-13 19:51:51 [config.py:358] Overriding max cuda graph capture size to 1024 for performance.
(APIServer pid=355) INFO 04-13 19:51:51 [vllm.py:747] Asynchronous scheduling is enabled.
(APIServer pid=355) :1301: FutureWarning: The cuda.cudart module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.runtime module instead.
(APIServer pid=355) :1301: FutureWarning: The cuda.nvrtc module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.nvrtc module instead.
/usr/local/lib/python3.12/dist-packages/transformers/utils/hub.py:110: FutureWarning: Using TRANSFORMERS_CACHE is deprecated and will be removed in v5 of Transformers. Use HF_HOME instead.
warnings.warn(
/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/chat_completion/protocol.py:346: SyntaxWarning: invalid escape sequence '\e'
"(e.g. 'abcdabcdabcd...' or '\emoji \emoji \emoji ...'). This feature "
/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/completion/protocol.py:176: SyntaxWarning: invalid escape sequence '\e'
"(e.g. 'abcdabcdabcd...' or '\emoji \emoji \emoji ...'). This feature "
(EngineCore_DP0 pid=437) INFO 04-13 19:52:07 [core.py:101] Initializing a V1 LLM engine (v0.17.0) with config: model='/data/model', speculative_config=None, tokenizer='/data/model', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=131072, download_dir=None, load_format=auto, tensor_parallel_size=2, pipeline_parallel_size=1, data_parallel_size=1, disable_custom_all_reduce=False, quantization=mxfp4, enforce_eager=False, enable_return_routed_experts=False, kv_cache_dtype=auto, device_config=cuda, structured_outputs_config=StructuredOutputsConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_parser='openai_gptoss', reasoning_parser_plugin='', enable_in_reasoning=False), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None, kv_cache_metrics=False, kv_cache_metrics_sample=0.01, cudagraph_metrics=False, enable_layerwise_nvtx_tracing=False, enable_mfu_metrics=False, enable_mm_processor_stats=False, enable_logging_iteration_details=False), seed=0, served_model_name=gpt-oss-120b-elliott, enable_prefix_caching=True, enable_chunked_prefill=True, pooler_config=None, compilation_config={'level': None, 'mode': <CompilationMode.VLLM_COMPILE: 3>, 'debug_dump_path': None, 'cache_dir': '', 'compile_cache_save_format': 'binary', 'backend': 'inductor', 'custom_ops': ['none'], 'splitting_ops': ['vllm::unified_attention', 'vllm::unified_attention_with_output', 'vllm::unified_mla_attention', 'vllm::unified_mla_attention_with_output', 'vllm::mamba_mixer2', 'vllm::mamba_mixer', 'vllm::short_conv', 'vllm::linear_attention', 'vllm::plamo2_mamba_mixer', 'vllm::gdn_attention_core', 'vllm::kda_attention', 'vllm::sparse_attn_indexer', 'vllm::rocm_aiter_sparse_attn_indexer', 'vllm::unified_kv_cache_update', 'vllm::unified_mla_kv_cache_update'], 
'compile_mm_encoder': False, 'compile_sizes': [], 'compile_ranges_split_points': [8192], 'inductor_compile_config': {'enable_auto_functionalized_v2': False, 'combo_kernels': True, 'benchmark_combo_kernel': True}, 'inductor_passes': {}, 'cudagraph_mode': <CUDAGraphMode.FULL_AND_PIECEWISE: (2, 1)>, 'cudagraph_num_of_warmups': 1, 'cudagraph_capture_sizes': [1, 2, 4, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248, 256, 272, 288, 304, 320, 336, 352, 368, 384, 400, 416, 432, 448, 464, 480, 496, 512, 528, 544, 560, 576, 592, 608, 624, 640, 656, 672, 688, 704, 720, 736, 752, 768, 784, 800, 816, 832, 848, 864, 880, 896, 912, 928, 944, 960, 976, 992, 1008, 1024], 'cudagraph_copy_inputs': False, 'cudagraph_specialize_lora': True, 'use_inductor_graph_partition': False, 'pass_config': {'fuse_norm_quant': False, 'fuse_act_quant': False, 'fuse_attn_quant': False, 'enable_sp': False, 'fuse_gemm_comms': False, 'fuse_allreduce_rms': True}, 'max_cudagraph_capture_size': 1024, 'dynamic_shapes_config': {'type': <DynamicShapesType.BACKED: 'backed'>, 'evaluate_guards': False, 'assume_32_bit_indexing': False}, 'local_cache_dir': None, 'fast_moe_cold_start': True, 'static_all_moe_layers': []}
(EngineCore_DP0 pid=437) WARNING 04-13 19:52:07 [multiproc_executor.py:945] Reducing Torch parallelism from 32 threads to 1 to avoid unnecessary CPU contention. Set OMP_NUM_THREADS in the external environment to tune this value as needed.
(EngineCore_DP0 pid=437) INFO 04-13 19:52:07 [multiproc_executor.py:134] DP group leader: node_rank=0, node_rank_within_dp=0, master_addr=127.0.0.1, mq_connect_ip=192.168.145.12 (local), world_size=2, local_world_size=2
/usr/local/lib/python3.12/dist-packages/transformers/utils/hub.py:110: FutureWarning: Using TRANSFORMERS_CACHE is deprecated and will be removed in v5 of Transformers. Use HF_HOME instead.
warnings.warn(
/usr/local/lib/python3.12/dist-packages/transformers/utils/hub.py:110: FutureWarning: Using TRANSFORMERS_CACHE is deprecated and will be removed in v5 of Transformers. Use HF_HOME instead.
warnings.warn(
/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/chat_completion/protocol.py:346: SyntaxWarning: invalid escape sequence '\e'
"(e.g. 'abcdabcdabcd...' or '\emoji \emoji \emoji ...'). This feature "
/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/chat_completion/protocol.py:346: SyntaxWarning: invalid escape sequence '\e'
"(e.g. 'abcdabcdabcd...' or '\emoji \emoji \emoji ...'). This feature "
/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/completion/protocol.py:176: SyntaxWarning: invalid escape sequence '\e'
"(e.g. 'abcdabcdabcd...' or '\emoji \emoji \emoji ...'). This feature "
/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/completion/protocol.py:176: SyntaxWarning: invalid escape sequence '\e'
"(e.g. 'abcdabcdabcd...' or '\emoji \emoji \emoji ...'). This feature "
(Worker pid=509) INFO 04-13 19:52:22 [parallel_state.py:1393] world_size=2 rank=1 local_rank=1 distributed_init_method=tcp://127.0.0.1:52915 backend=nccl
(Worker pid=508) INFO 04-13 19:52:22 [parallel_state.py:1393] world_size=2 rank=0 local_rank=0 distributed_init_method=tcp://127.0.0.1:52915 backend=nccl
(Worker pid=508) :1301: FutureWarning: The cuda.cudart module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.runtime module instead.
(Worker pid=509) :1301: FutureWarning: The cuda.cudart module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.runtime module instead.
(Worker pid=509) :1301: FutureWarning: The cuda.nvrtc module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.nvrtc module instead.
(Worker pid=508) :1301: FutureWarning: The cuda.nvrtc module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.nvrtc module instead.
(Worker pid=508) INFO 04-13 19:52:23 [pynccl.py:111] vLLM is using nccl==2.27.5
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] WorkerProc failed to start.
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] Traceback (most recent call last):
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/multiproc_executor.py", line 771, in worker_main
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] worker = WorkerProc(*args, **kwargs)
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] File "/usr/local/lib/python3.12/dist-packages/vllm/tracing/otel.py", line 178, in sync_wrapper
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] return func(*args, **kwargs)
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] ^^^^^^^^^^^^^^^^^^^^^
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/multiproc_executor.py", line 592, in __init__
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] self.worker.init_device()
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/worker_base.py", line 326, in init_device
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] self.worker.init_device() # type: ignore
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] ^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] File "/usr/local/lib/python3.12/dist-packages/vllm/tracing/otel.py", line 178, in sync_wrapper
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] return func(*args, **kwargs)
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] ^^^^^^^^^^^^^^^^^^^^^
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_worker.py", line 265, in init_device
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] init_worker_distributed_environment(
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_worker.py", line 945, in init_worker_distributed_environment
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] ensure_model_parallel_initialized(
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/parallel_state.py", line 1747, in ensure_model_parallel_initialized
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] initialize_model_parallel(
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/parallel_state.py", line 1562, in initialize_model_parallel
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] _TP = init_model_parallel_group(
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] ^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/parallel_state.py", line 1152, in init_model_parallel_group
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] return GroupCoordinator(
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] ^^^^^^^^^^^^^^^^^
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/parallel_state.py", line 373, in __init__
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] self.device_communicator = device_comm_cls(
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] ^^^^^^^^^^^^^^^^
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/cuda_communicator.py", line 75, in __init__
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] self.pynccl_comm = PyNcclCommunicator(
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] ^^^^^^^^^^^^^^^^^^^
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/pynccl.py", line 139, in __init__
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/pynccl_wrapper.py", line 407, in ncclCommInitRank
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] self.NCCL_CHECK(
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/pynccl_wrapper.py", line 373, in NCCL_CHECK
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] raise RuntimeError(f"NCCL error: {error_str}")
(Worker pid=508) ERROR 04-13 19:52:24 [multiproc_executor.py:800] RuntimeError: NCCL error: unhandled cuda error (run with NCCL_DEBUG=INFO for details)
pais-modelendpoint-c3dea45f-4e72-4a34-8888-7c855e768973-8579zt8:508:508 [0] NCCL INFO Bootstrap: Using eth0:192.168.145.12<0>
pais-modelendpoint-c3dea45f-4e72-4a34-8888-7c855e768973-8579zt8:508:508 [0] NCCL INFO cudaDriverVersion 12090
pais-modelendpoint-c3dea45f-4e72-4a34-8888-7c855e768973-8579zt8:508:508 [0] NCCL INFO NCCL version 2.27.5+cuda12.9
pais-modelendpoint-c3dea45f-4e72-4a34-8888-7c855e768973-8579zt8:508:508 [0] NCCL INFO NET/Plugin: Could not find: libnccl-net.so.
pais-modelendpoint-c3dea45f-4e72-4a34-8888-7c855e768973-8579zt8:508:508 [0] NCCL INFO NET/IB : No device found.
pais-modelendpoint-c3dea45f-4e72-4a34-8888-7c855e768973-8579zt8:508:508 [0] NCCL INFO NET/IB : Using [RO]; OOB eth0:192.168.145.12<0>
pais-modelendpoint-c3dea45f-4e72-4a34-8888-7c855e768973-8579zt8:508:508 [0] NCCL INFO NET/Socket : Using [0]eth0:192.168.145.12<0>
pais-modelendpoint-c3dea45f-4e72-4a34-8888-7c855e768973-8579zt8:508:508 [0] NCCL INFO Initialized NET plugin Socket
pais-modelendpoint-c3dea45f-4e72-4a34-8888-7c855e768973-8579zt8:508:508 [0] NCCL INFO Assigned NET plugin Socket to comm
pais-modelendpoint-c3dea45f-4e72-4a34-8888-7c855e768973-8579zt8:508:508 [0] NCCL INFO Using network Socket

[2026-04-13 19:52:24] pais-modelendpoint-c3dea45f-4e72-4a34-8888-7c855e768973-8579zt8:508:508 [0] init.cc:426 NCCL WARN Cuda failure 'operation not supported'
pais-modelendpoint-c3dea45f-4e72-4a34-8888-7c855e768973-8579zt8:508:508 [0] NCCL INFO init.cc:1437 -> 1
pais-modelendpoint-c3dea45f-4e72-4a34-8888-7c855e768973-8579zt8:508:508 [0] NCCL INFO init.cc:1832 -> 1
pais-modelendpoint-c3dea45f-4e72-4a34-8888-7c855e768973-8579zt8:508:508 [0] NCCL INFO init.cc:1858 -> 1
(Worker pid=508) INFO 04-13 19:52:24 [multiproc_executor.py:749] Parent process exited, terminating worker
(Worker pid=509) INFO 04-13 19:52:24 [multiproc_executor.py:749] Parent process exited, terminating worker
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] WorkerProc failed to start.
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] Traceback (most recent call last):
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/multiproc_executor.py", line 771, in worker_main
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] worker = WorkerProc(*args, **kwargs)
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] File "/usr/local/lib/python3.12/dist-packages/vllm/tracing/otel.py", line 178, in sync_wrapper
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] return func(*args, **kwargs)
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] ^^^^^^^^^^^^^^^^^^^^^
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/multiproc_executor.py", line 592, in __init__
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] self.worker.init_device()
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/worker_base.py", line 326, in init_device
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] self.worker.init_device() # type: ignore
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] ^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] File "/usr/local/lib/python3.12/dist-packages/vllm/tracing/otel.py", line 178, in sync_wrapper
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] return func(*args, **kwargs)
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] ^^^^^^^^^^^^^^^^^^^^^
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_worker.py", line 265, in init_device
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] init_worker_distributed_environment(
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_worker.py", line 945, in init_worker_distributed_environment
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] ensure_model_parallel_initialized(
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/parallel_state.py", line 1747, in ensure_model_parallel_initialized
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] initialize_model_parallel(
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/parallel_state.py", line 1562, in initialize_model_parallel
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] _TP = init_model_parallel_group(
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] ^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/parallel_state.py", line 1152, in init_model_parallel_group
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] return GroupCoordinator(
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] ^^^^^^^^^^^^^^^^^
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/parallel_state.py", line 373, in __init__
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] self.device_communicator = device_comm_cls(
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] ^^^^^^^^^^^^^^^^
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/cuda_communicator.py", line 75, in __init__
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] self.pynccl_comm = PyNcclCommunicator(
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] ^^^^^^^^^^^^^^^^^^^
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/pynccl.py", line 139, in __init__
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/pynccl_wrapper.py", line 407, in ncclCommInitRank
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] self.NCCL_CHECK(
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/pynccl_wrapper.py", line 373, in NCCL_CHECK
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] raise RuntimeError(f"NCCL error: {error_str}")
(Worker pid=509) ERROR 04-13 19:52:24 [multiproc_executor.py:800] RuntimeError: NCCL error: unhandled cuda error (run with NCCL_DEBUG=INFO for details)
pais-modelendpoint-c3dea45f-4e72-4a34-8888-7c855e768973-8579zt8:509:509 [1] NCCL INFO cudaDriverVersion 12090
pais-modelendpoint-c3dea45f-4e72-4a34-8888-7c855e768973-8579zt8:509:509 [1] NCCL INFO Bootstrap: Using eth0:192.168.145.12<0>
pais-modelendpoint-c3dea45f-4e72-4a34-8888-7c855e768973-8579zt8:509:509 [1] NCCL INFO NCCL version 2.27.5+cuda12.9
pais-modelendpoint-c3dea45f-4e72-4a34-8888-7c855e768973-8579zt8:509:509 [1] NCCL INFO NET/Plugin: Could not find: libnccl-net.so.
pais-modelendpoint-c3dea45f-4e72-4a34-8888-7c855e768973-8579zt8:509:509 [1] NCCL INFO NET/IB : No device found.
pais-modelendpoint-c3dea45f-4e72-4a34-8888-7c855e768973-8579zt8:509:509 [1] NCCL INFO NET/IB : Using [RO]; OOB eth0:192.168.145.12<0>
pais-modelendpoint-c3dea45f-4e72-4a34-8888-7c855e768973-8579zt8:509:509 [1] NCCL INFO NET/Socket : Using [0]eth0:192.168.145.12<0>
pais-modelendpoint-c3dea45f-4e72-4a34-8888-7c855e768973-8579zt8:509:509 [1] NCCL INFO Initialized NET plugin Socket
pais-modelendpoint-c3dea45f-4e72-4a34-8888-7c855e768973-8579zt8:509:509 [1] NCCL INFO Assigned NET plugin Socket to comm
pais-modelendpoint-c3dea45f-4e72-4a34-8888-7c855e768973-8579zt8:509:509 [1] NCCL INFO Using network Socket

[2026-04-13 19:52:24] pais-modelendpoint-c3dea45f-4e72-4a34-8888-7c855e768973-8579zt8:509:509 [1] init.cc:426 NCCL WARN Cuda failure 'operation not supported'
pais-modelendpoint-c3dea45f-4e72-4a34-8888-7c855e768973-8579zt8:509:509 [1] NCCL INFO init.cc:1437 -> 1
pais-modelendpoint-c3dea45f-4e72-4a34-8888-7c855e768973-8579zt8:509:509 [1] NCCL INFO init.cc:1832 -> 1
pais-modelendpoint-c3dea45f-4e72-4a34-8888-7c855e768973-8579zt8:509:509 [1] NCCL INFO init.cc:1858 -> 1
[rank0]:[W413 19:52:25.475343635 ProcessGroupNCCL.cpp:1553] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
(EngineCore_DP0 pid=437) ERROR 04-13 19:52:26 [core.py:1100] EngineCore failed to start.
(EngineCore_DP0 pid=437) ERROR 04-13 19:52:26 [core.py:1100] Traceback (most recent call last):
(EngineCore_DP0 pid=437) ERROR 04-13 19:52:26 [core.py:1100] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 1090, in run_engine_core
(EngineCore_DP0 pid=437) ERROR 04-13 19:52:26 [core.py:1100] engine_core = EngineCoreProc(*args, engine_index=dp_rank, **kwargs)
(EngineCore_DP0 pid=437) ERROR 04-13 19:52:26 [core.py:1100] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=437) ERROR 04-13 19:52:26 [core.py:1100] File "/usr/local/lib/python3.12/dist-packages/vllm/tracing/otel.py", line 178, in sync_wrapper
(EngineCore_DP0 pid=437) ERROR 04-13 19:52:26 [core.py:1100] return func(*args, **kwargs)
(EngineCore_DP0 pid=437) ERROR 04-13 19:52:26 [core.py:1100] ^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=437) ERROR 04-13 19:52:26 [core.py:1100] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 834, in __init__
(EngineCore_DP0 pid=437) ERROR 04-13 19:52:26 [core.py:1100] super().__init__(
(EngineCore_DP0 pid=437) ERROR 04-13 19:52:26 [core.py:1100] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 110, in __init__
(EngineCore_DP0 pid=437) ERROR 04-13 19:52:26 [core.py:1100] self.model_executor = executor_class(vllm_config)
(EngineCore_DP0 pid=437) ERROR 04-13 19:52:26 [core.py:1100] ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=437) ERROR 04-13 19:52:26 [core.py:1100] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/multiproc_executor.py", line 100, in __init__
(EngineCore_DP0 pid=437) ERROR 04-13 19:52:26 [core.py:1100] super().__init__(vllm_config)
(EngineCore_DP0 pid=437) ERROR 04-13 19:52:26 [core.py:1100] File "/usr/local/lib/python3.12/dist-packages/vllm/tracing/otel.py", line 178, in sync_wrapper
(EngineCore_DP0 pid=437) ERROR 04-13 19:52:26 [core.py:1100] return func(*args, **kwargs)
(EngineCore_DP0 pid=437) ERROR 04-13 19:52:26 [core.py:1100] ^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=437) ERROR 04-13 19:52:26 [core.py:1100] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/abstract.py", line 103, in __init__
(EngineCore_DP0 pid=437) ERROR 04-13 19:52:26 [core.py:1100] self._init_executor()
(EngineCore_DP0 pid=437) ERROR 04-13 19:52:26 [core.py:1100] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/multiproc_executor.py", line 180, in _init_executor
(EngineCore_DP0 pid=437) ERROR 04-13 19:52:26 [core.py:1100] self.workers = WorkerProc.wait_for_ready(unready_workers)
(EngineCore_DP0 pid=437) ERROR 04-13 19:52:26 [core.py:1100] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=437) ERROR 04-13 19:52:26 [core.py:1100] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/multiproc_executor.py", line 697, in wait_for_ready
(EngineCore_DP0 pid=437) ERROR 04-13 19:52:26 [core.py:1100] raise e from None
(EngineCore_DP0 pid=437) ERROR 04-13 19:52:26 [core.py:1100] Exception: WorkerProc initialization failed due to an exception in a background process. See stack trace for root cause.
(EngineCore_DP0 pid=437) Process EngineCore_DP0:
(EngineCore_DP0 pid=437) Traceback (most recent call last):
(EngineCore_DP0 pid=437) File "/usr/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
(EngineCore_DP0 pid=437) self.run()
(EngineCore_DP0 pid=437) File "/usr/lib/python3.12/multiprocessing/process.py", line 108, in run
(EngineCore_DP0 pid=437) self._target(*self._args, **self._kwargs)
(EngineCore_DP0 pid=437) File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 1104, in run_engine_core
(EngineCore_DP0 pid=437) raise e
(EngineCore_DP0 pid=437) File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 1090, in run_engine_core
(EngineCore_DP0 pid=437) engine_core = EngineCoreProc(*args, engine_index=dp_rank, **kwargs)
(EngineCore_DP0 pid=437) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=437) File "/usr/local/lib/python3.12/dist-packages/vllm/tracing/otel.py", line 178, in sync_wrapper
(EngineCore_DP0 pid=437) return func(*args, **kwargs)
(EngineCore_DP0 pid=437) ^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=437) File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 834, in __init__
(EngineCore_DP0 pid=437) super().__init__(
(EngineCore_DP0 pid=437) File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 110, in __init__
(EngineCore_DP0 pid=437) self.model_executor = executor_class(vllm_config)
(EngineCore_DP0 pid=437) ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=437) File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/multiproc_executor.py", line 100, in __init__
(EngineCore_DP0 pid=437) super().__init__(vllm_config)
(EngineCore_DP0 pid=437) File "/usr/local/lib/python3.12/dist-packages/vllm/tracing/otel.py", line 178, in sync_wrapper
(EngineCore_DP0 pid=437) return func(*args, **kwargs)
(EngineCore_DP0 pid=437) ^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=437) File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/abstract.py", line 103, in __init__
(EngineCore_DP0 pid=437) self._init_executor()
(EngineCore_DP0 pid=437) File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/multiproc_executor.py", line 180, in _init_executor
(EngineCore_DP0 pid=437) self.workers = WorkerProc.wait_for_ready(unready_workers)
(EngineCore_DP0 pid=437) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=437) File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/multiproc_executor.py", line 697, in wait_for_ready
(EngineCore_DP0 pid=437) raise e from None
(EngineCore_DP0 pid=437) Exception: WorkerProc initialization failed due to an exception in a background process. See stack trace for root cause.
(APIServer pid=355) Traceback (most recent call last):
(APIServer pid=355) File "<frozen runpy>", line 198, in _run_module_as_main
(APIServer pid=355) File "<frozen runpy>", line 88, in _run_code
(APIServer pid=355) File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/api_server.py", line 545, in <module>
(APIServer pid=355) uvloop.run(run_server(args))
(APIServer pid=355) File "/usr/local/lib/python3.12/dist-packages/uvloop/__init__.py", line 96, in run
(APIServer pid=355) return __asyncio.run(
(APIServer pid=355) ^^^^^^^^^^^^^^
(APIServer pid=355) File "/usr/lib/python3.12/asyncio/runners.py", line 195, in run
(APIServer pid=355) return runner.run(main)
(APIServer pid=355) ^^^^^^^^^^^^^^^^
(APIServer pid=355) File "/usr/lib/python3.12/asyncio/runners.py", line 118, in run
(APIServer pid=355) return self._loop.run_until_complete(task)
(APIServer pid=355) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=355) File "uvloop/loop.pyx", line 1518, in uvloop.loop.Loop.run_until_complete
(APIServer pid=355) File "/usr/local/lib/python3.12/dist-packages/uvloop/__init__.py", line 48, in wrapper
(APIServer pid=355) return await main
(APIServer pid=355) ^^^^^^^^^^
(APIServer pid=355) File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/api_server.py", line 471, in run_server
(APIServer pid=355) await run_server_worker(listen_address, sock, args, **uvicorn_kwargs)
(APIServer pid=355) File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/api_server.py", line 490, in run_server_worker
(APIServer pid=355) async with build_async_engine_client(
(APIServer pid=355) ^^^^^^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=355) File "/usr/lib/python3.12/contextlib.py", line 210, in __aenter__
(APIServer pid=355) return await anext(self.gen)
(APIServer pid=355) ^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=355) File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/api_server.py", line 96, in build_async_engine_client
(APIServer pid=355) async with build_async_engine_client_from_engine_args(
(APIServer pid=355) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=355) File "/usr/lib/python3.12/contextlib.py", line 210, in __aenter__
(APIServer pid=355) return await anext(self.gen)
(APIServer pid=355) ^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=355) File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/api_server.py", line 137, in build_async_engine_client_from_engine_args
(APIServer pid=355) async_llm = AsyncLLM.from_vllm_config(
(APIServer pid=355) ^^^^^^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=355) File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/async_llm.py", line 225, in from_vllm_config
(APIServer pid=355) return cls(
(APIServer pid=355) ^^^^
(APIServer pid=355) File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/async_llm.py", line 154, in __init__
(APIServer pid=355) self.engine_core = EngineCoreClient.make_async_mp_client(
(APIServer pid=355) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=355) File "/usr/local/lib/python3.12/dist-packages/vllm/tracing/otel.py", line 178, in sync_wrapper
(APIServer pid=355) return func(*args, **kwargs)
(APIServer pid=355) ^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=355) File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core_client.py", line 127, in make_async_mp_client
(APIServer pid=355) return AsyncMPClient(*client_args)
(APIServer pid=355) ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=355) File "/usr/local/lib/python3.12/dist-packages/vllm/tracing/otel.py", line 178, in sync_wrapper
(APIServer pid=355) return func(*args, **kwargs)
(APIServer pid=355) ^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=355) File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core_client.py", line 911, in __init__
(APIServer pid=355) super().__init__(
(APIServer pid=355) File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core_client.py", line 569, in __init__
(APIServer pid=355) with launch_core_engines(
(APIServer pid=355) ^^^^^^^^^^^^^^^^^^^^
(APIServer pid=355) File "/usr/lib/python3.12/contextlib.py", line 144, in __exit__
(APIServer pid=355) next(self.gen)
(APIServer pid=355) File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/utils.py", line 951, in launch_core_engines
(APIServer pid=355) wait_for_engine_startup(
(APIServer pid=355) File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/utils.py", line 1010, in wait_for_engine_startup
(APIServer pid=355) raise RuntimeError(
(APIServer pid=355) RuntimeError: Engine core initialization failed. See root cause above. Failed core proc(s): {}
/usr/lib/python3.12/multiprocessing/resource_tracker.py:279: UserWarning: resource_tracker: There appear to be 2 leaked shared_memory objects to clean up at shutdown

Steps to Reproduce the Issue

  1. Provision a VMware VM with 2x H100 vGPUs using the H100XM-80C profile.
  2. Ensure UVM is disabled at the hypervisor level (do not set pciPassthruX.cfg.enable_uvm=1).
  3. Install PyTorch with NCCL 2.27.5.
  4. Run the following minimal reproduction script:
  import os
  import torch
  import torch.distributed as dist
  import torch.multiprocessing as mp
  def run_worker(rank, world_size):
      os.environ["MASTER_ADDR"] = "127.0.0.1"
      os.environ["MASTER_PORT"] = "29501"
      dist.init_process_group("nccl", rank=rank, world_size=world_size)

      torch.cuda.set_device(rank)
      tensor = torch.ones(10).cuda(rank)
      dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
      print(f"[Rank {rank}] SUCCESS!")
      dist.destroy_process_group()
  if __name__ == "__main__":
      mp.spawn(run_worker, args=(2,), nprocs=2, join=True)
  5. Execute with: NCCL_CUMEM_ENABLE=0 NCCL_CUMEM_HOST_ENABLE=0 NCCL_DEBUG=INFO NCCL_DEBUG_SUBSYS=INIT python3 test.py

NCCL Version

2.27.5

Your platform details

• Hypervisor: VMware ESXi (vSphere 8/9)
• GPU Hardware: 2x NVIDIA H100 (NV18 connectivity confirmed via nvidia-smi topo -m)
• vGPU Profile: H100XM-80C
• Addressing Mode: HMM (UVM disabled)
• Guest OS: Ubuntu 22.04
• CUDA Driver: 580.x / CUDA 12.8
• NCCL Version: 2.27.5 (Failing) / 2.21.5 (Working)
• Framework: PyTorch 2.10.0 / vLLM 0.19.0

Error Message & Behavior

Expected Behavior

NCCL should respect the NCCL_CUMEM_ENABLE=0 and NCCL_CUMEM_HOST_ENABLE=0 environment variables, bypass the cuMem / memory-pool APIs that require UVM, and initialize successfully using the legacy allocation paths, as NCCL 2.21.5 does.
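For reference, the gating these flags are expected to perform can be modeled in a few lines of Python. This is a simplification for illustration only, not NCCL's actual code; `driver_supports_vmm` is a hypothetical stub standing in for a driver capability query:

```python
import os

def driver_supports_vmm():
    # Hypothetical stand-in for querying the CUDA driver capability
    # (CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED).
    return True

def cumem_enabled(env=None):
    """Expected gating: an explicit NCCL_CUMEM_ENABLE=0 forces the legacy
    cudaMalloc path regardless of what the driver advertises."""
    env = os.environ if env is None else env
    val = env.get("NCCL_CUMEM_ENABLE")
    if val is not None:
        return int(val) != 0       # explicit user override wins
    return driver_supports_vmm()   # otherwise auto-detect
```

Under this model, NCCL_CUMEM_ENABLE=0 must short-circuit before any cuMem call is issued; the crash at init.cc:426 suggests the real check happens after (or independently of) the offending CUDA call.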

Actual Behavior

The process crashes immediately during ring initialization:

gpu-debug:936:1010 [1] NCCL INFO cudaDriverVersion 12080
gpu-debug:936:1010 [1] NCCL INFO NCCL version 2.27.5+cuda12.9
...
[2026-04-14 14:29:12] gpu-debug:936:1010 [1] init.cc:426 NCCL WARN Cuda failure 'operation not supported'
torch.distributed.DistBackendError: NCCL error in: /pytorch/torch/csrc/distributed/c10d/NCCLUtils.cpp:93, unhandled cuda error (run with NCCL_DEBUG=INFO for details), NCCL version 2.27.5
ncclUnhandledCudaError: Call to CUDA function failed.
Last error:
Cuda failure 'operation not supported'

Troubleshooting & Workarounds Attempted:

  1. Verified Hardware P2P: Ran torch.cuda.can_device_access_peer(0, 1) which returns True. Direct tensor copying (t1.copy_(t0)) succeeds. Hardware NVLink is functional.
  2. Downgrade NCCL: Downgrading to NCCL 2.21.5 makes the script run successfully. However, this is not a viable long-term workaround as modern frameworks (vLLM 0.18+, PyTorch 2.9+) require newer NCCL versions.
  3. Environment Variables: Attempted setting NCCL_P2P_DISABLE=1, NCCL_SHM_DISABLE=1, NCCL_CUMEM_ENABLE=0, and NCCL_CUMEM_HOST_ENABLE=0. None of these bypass the failing CUDA call at init.cc:426.
  4. Enable UVM: Adding pciPassthruX.cfg.enable_uvm=1 to the ESXi VM configuration resolves the NCCL crash, but it explicitly disables VMware vMotion ("vGPU migration is not supported on this VM"), which violates our enterprise high-availability requirements.
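To confirm independently whether the guest driver advertises the cuMem (virtual memory management) APIs at all, the relevant device attribute can be queried directly. A rough ctypes sketch follows; it assumes the attribute ID 102 (CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED in cuda.h) and returns None on machines where libcuda is not available:

```python
import ctypes
import ctypes.util

# cuda.h: CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED
# (assumed attribute ID; introduced with the cuMem* APIs in CUDA 10.2)
CU_ATTR_VMM_SUPPORTED = 102

def vmm_supported(device=0):
    """Return True/False for VMM support on `device`, or None when the
    CUDA driver library cannot be loaded or the query fails."""
    path = ctypes.util.find_library("cuda")
    if path is None:
        return None
    cu = ctypes.CDLL(path)
    if cu.cuInit(0) != 0:
        return None
    val = ctypes.c_int(0)
    rc = cu.cuDeviceGetAttribute(
        ctypes.byref(val), CU_ATTR_VMM_SUPPORTED, device)
    if rc != 0:
        return None
    return bool(val.value)

if __name__ == "__main__":
    print("cuMem/VMM supported:", vmm_supported())
```

If this reports False inside the vGPU guest with UVM disabled, any unconditional cuMem call in NCCL's init path would be expected to fail with 'operation not supported'.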
