Skip to content

fix: handle empty samples and add CUDA fallback#182

Open
mshzy wants to merge 1 commit into
SesameAILabs:mainfrom
mshzy:fix/empty-samples-cuda
Open

fix: handle empty samples and add CUDA fallback#182
mshzy wants to merge 1 commit into
SesameAILabs:mainfrom
mshzy:fix/empty-samples-cuda

Conversation

@mshzy
Copy link
Copy Markdown

@mshzy mshzy commented Jun 2, 2026

Fixes: 1) torch.stack crash when no samples generated, 2) load_csm_1b crash on CPU-only systems

- generator.generate(): return zero tensor when no audio samples generated
- load_csm_1b(): auto-fallback to CPU when CUDA unavailable
Copy link
Copy Markdown
Collaborator

@ZackHodari ZackHodari left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice fix

Small cleanup and a question about dtypes

Comment thread generator.py
warnings.warn("CUDA not available, falling back to CPU")
model = Model.from_pretrained("sesame/csm-1b")
model.to(device=device, dtype=torch.bfloat16)
model.to(device=device, dtype=torch.bfloat16 if device == "cuda" else torch.float32)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bfloat16 should work on CPU, please share more details if you tested it and found it doesn't work

Comment thread generator.py
def load_csm_1b(device: str = "cuda") -> Generator:
if device == "cuda" and not torch.cuda.is_available():
device = "cpu"
import warnings
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move import to top level

Comment thread generator.py

if not samples:
# No audio generated (e.g., empty prompt or immediate EOS)
return torch.zeros(1, device=self.device)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer it returns a tensor with 0 samples, we shouldn't make up data

Suggested change
return torch.zeros(1, device=self.device)
return torch.zeros(0, device=self.device)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants