fix metax & update ppu #1150
Code Review
This pull request updates the parallel environment initialization in lightx2v_platform/base/metax_cuda.py to use torch.cuda.set_device instead of torch.npu.set_device. Feedback points out that using the global rank (dist.get_rank()) to set the CUDA device will cause failures in multi-node distributed training, and suggests using the local rank from the LOCAL_RANK environment variable instead.
  def init_parallel_env():
      dist.init_process_group(backend="nccl")
-     torch.npu.set_device(dist.get_rank())
+     torch.cuda.set_device(dist.get_rank())
Using dist.get_rank() to set the CUDA device will cause failures in multi-node distributed training. dist.get_rank() returns the global rank, which can exceed the number of GPUs available on a single node (e.g., rank 8 on a second 8-GPU node). Instead, you should use the local rank (typically retrieved from the LOCAL_RANK environment variable) to set the device.
Suggested change:
-     torch.cuda.set_device(dist.get_rank())
+     import os
+     torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
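For illustration, here is a slightly fuller sketch of the same idea. The helper name `resolve_local_rank` is hypothetical and not part of lightx2v_platform. The modulo fallback for a missing LOCAL_RANK is an assumption: it only holds when ranks are assigned contiguously per node and every node exposes the same number of devices. Launchers such as torchrun set LOCAL_RANK, so the fallback should rarely be reached.

```python
import os


def resolve_local_rank(global_rank, gpus_per_node, env=None):
    """Return the device index this process should use on its own node.

    Hypothetical helper: prefers LOCAL_RANK from the environment and only
    falls back to arithmetic on the global rank when it is missing.
    """
    env = os.environ if env is None else env
    local_rank = env.get("LOCAL_RANK")
    if local_rank is not None:
        return int(local_rank)
    # Fallback (assumption): ranks are contiguous per node and all nodes
    # have the same device count, so global rank 8 on 8-GPU nodes maps to 0.
    if gpus_per_node <= 0:
        raise ValueError("gpus_per_node must be positive")
    return global_rank % gpus_per_node


def init_parallel_env():
    # Imported lazily so resolve_local_rank stays usable without torch.
    import torch
    import torch.distributed as dist

    dist.init_process_group(backend="nccl")
    device_index = resolve_local_rank(dist.get_rank(), torch.cuda.device_count())
    torch.cuda.set_device(device_index)
```

With two 8-GPU nodes, global rank 11 resolves to device 3 through the fallback, while any process with LOCAL_RANK set uses that value directly regardless of its global rank.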
No description provided.